diff --git a/.gitignore b/.gitignore index 0449bb94d..12e0533b1 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,4 @@ example.py /gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu /gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu /gptqmodel_offload/ +/gptqmodel_ext/machete/generated/ diff --git a/MANIFEST.in b/MANIFEST.in index 9efddd22b..a18a3e2c1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,9 +3,12 @@ recursive-include gptqmodel_ext/exllama *.h *.cuh *.cu *.cpp recursive-include gptqmodel_ext/exllamav2 *.h *.cuh *.cu *.cpp recursive-include gptqmodel_ext/exllama_eora/eora *.h *.cuh *.cu *.cpp *.py recursive-include gptqmodel_ext/marlin *.h *.cuh *.cu *.cpp +recursive-include gptqmodel_ext/machete *.h *.hpp *.cuh *.cu *.cpp *.py +recursive-include gptqmodel_ext/cutlass_extensions *.h *.hpp *.cuh *.cu *.cpp *.py recursive-include gptqmodel_ext/qqq *.h *.cuh *.cu *.cpp include gptqmodel_ext/pack_block_cpu.cpp include gptqmodel_ext/marlin/generate_kernels.py +include gptqmodel_ext/machete/generate.py recursive-exclude gptqmodel_ext __pycache__ *.pyc prune tests/ prune format/ diff --git a/README.md b/README.md index 04a5d6fe1..f008ef433 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,9 @@

## Latest News -* 10/20/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by +* 10/21/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and now inlined with quantization. `Vram` pressure for large models reduced during quantization. +`Machete` kernel added for Hopper+/Blackwell acceleration for gptq and awq models. `act_group_aware` is 16k+ times faster and now the default when `desc_act=False` for higher quality recovery without inference penalty of `desc_act=True`. New beta quality `AWQ` support with full `gemm`, `gemm_fast`, `marlin` kernel support. `LFM`, `Ling`, `Qwen3 Omni` model support. Quantization is now faster with reduced vram usage. Enhanced logging support with `LogBar`. * 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls. @@ -196,14 +197,14 @@ Native support support some of the most popular multi-modal models: GPT-QModel is validated for Linux, MacOS, and Windows 11: -| Platform | Device | | Optimized Arch | Kernels | -|-----------------|---------------| --- | -------------- |-----------------------------------------------| -| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch | -| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch | -| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (Python 2.8+), Torch | -| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused, Torch | -| 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion | -| 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch | +| Platform | Device | | Optimized Arch | Kernels | +|-----------------|---------------| --- | -------------- |--------------------------------------------------------| +| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Machete, Marlin, Exllama V2, Exallma V1, Triton, Torch | +| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch | +| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (Python 2.8+), Torch | +| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused (Python 2.8+), Torch | +| 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion | +| 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch | ## Install diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index c7fca5666..9a55fcab3 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -10,7 +10,6 @@ from importlib.metadata import PackageNotFoundError, version from typing import Dict, List, Optional, Union -import accelerate import torch import transformers @@ -38,6 +37,7 @@ from ..utils.backend import BACKEND from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear from ..utils.logger import setup_logger +from ..utils.machete import _validate_machete_device_support from ..utils.marlin import _validate_marlin_device_support from ..utils.model import ( auto_dtype, @@ -478,7 +478,6 @@ def skip(*args, **kwargs): init_contexts = [no_init_weights()] - layer_type = "" with (ContextManagers(init_contexts)): cls.before_model_load(cls, load_quantized_model=True) @@ -507,8 +506,7 @@ def skip(*args, **kwargs): # Get the first layer to determine layer type layers, _ = get_module_by_name_prefix(model, cls.extract_layers_node()) - layer0 = layers[0] - layer_type = layer0.__class__.__name__ + layers[0] modules = find_modules(model) ignore_modules = [cls.lm_head] + cls.get_base_modules(model) @@ -535,7 +533,6 @@ def skip(*args, **kwargs): device=device, ) - log.debug(f"Loader1: device_map {device_map}") if isinstance(device_map, str) and device_map not in [ "auto", "balanced", @@ -548,8 +545,8 @@ def skip(*args, **kwargs): ) + import torch - from typing import Dict, List, Optional def build_layerwise_device_map( model, @@ -643,7 +640,7 @@ def assign(mod, device_id): if owner: device_map.setdefault(owner, fallback_device) else: - log.debug(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.") + log.info(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.") # ------------------------------------------------------------- # 6. Prune parent assignments that would override child devices @@ -657,11 +654,11 @@ def assign(mod, device_id): if child_name != name and child_name.startswith(f"{name}.") } if child_devices and (len(child_devices) > 1 or device_id not in child_devices): - log.debug(f"Loader: dropping parent '{name}' from device_map to preserve child placements.") + log.info(f"Loader: dropping parent '{name}' from device_map to preserve child placements.") device_map.pop(name, None) # optional logging for debug - log.debug(f"Loader: Built map across {num_gpus} GPU(s), " + log.info(f"Loader: Built map across {num_gpus} GPU(s), " f"{len(device_map)} entries. First 8: {list(device_map.items())[:8]}") return device_map @@ -707,6 +704,16 @@ def assign(mod, device_id): qcfg.runtime_format = FORMAT.GPTQ_V2 + if backend == BACKEND.MACHETE: + if is_sharded: + raise ValueError( + "Format: The loading of sharded checkpoints with Machete is currently not supported." + ) + if not _validate_machete_device_support(): + raise ValueError( + f"Kernel: Machete kernel requires compute capability >= 9.0. Detected capability: {torch.cuda.get_device_capability()}" + ) + if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and ( preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN): if is_sharded: @@ -742,7 +749,7 @@ def assign(mod, device_id): # If we use marlin or bitblas to load the quantized model, the model is already a converted model, # and we no longer need to call load_checkpoint_in_model() - if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]: + if load_checkpoint_in_model and backend not in [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]: load_checkpoint_in_model_then_tie_weights( model, dtype=dtype, diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py index 867874a26..6e002ae7a 100644 --- a/gptqmodel/nn_modules/hooked_linear.py +++ b/gptqmodel/nn_modules/hooked_linear.py @@ -241,8 +241,8 @@ def _replace_module(module, child, name, level: int = 0, debug: bool = False) -> def replace_module_with_hooked_legacy(module, level: int = 0, quant_lm_head: bool = False): - if level == 0: - log.info("Hooked Modules: Using legacy based config for targeting of modules") + # if level == 0: + # log.info("Hooked Modules: Using legacy based config for targeting of modules") for name, child in module.named_children(): if not quant_lm_head and hasattr(module, "get_output_embeddings") and child == module.get_output_embeddings(): diff --git a/gptqmodel/nn_modules/qlinear/awq_machete.py b/gptqmodel/nn_modules/qlinear/awq_machete.py new file mode 100644 index 000000000..622d7dd35 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/awq_machete.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...nn_modules.qlinear import AWQuantLinear +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from ...utils.machete import ( + _validate_machete_device_support, + machete_import_exception, + machete_mm, + machete_prepack_B, + pack_quantized_values_into_int32, +) +from ...utils.marlin import replace_parameter, unpack_cols +from ...utils.marlin_scalar_type import scalar_types +from ...utils.rocm import IS_ROCM + + +log = setup_logger() + + +class AwqMacheteQuantLinear(AWQuantLinear): + SUPPORTS_BITS = [4, 8] + SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128] + SUPPORTS_DESC_ACT = [False] # AWQ kernels do not reorder activations + SUPPORTS_SYM = [True, False] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [128] + + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + REQUIRES_FORMAT_V2 = False + + QUANT_TYPE = "awq_machete" + + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + def __init__( + self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = False, + **kwargs): + if machete_import_exception is not None: + raise ValueError( + "Trying to use the machete backend, but could not import the " + f"C++/CUDA dependencies with the following error: {machete_import_exception}" + ) + + if bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {bits}. Supported: {list(self.TYPE_MAP.keys())}") + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.MACHETE), + adapter=adapter, + register_buffers=register_buffers, + **kwargs) + + self.weight_type = self.TYPE_MAP[self.bits] + self.has_zero_points = True + + @classmethod + def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: + if machete_import_exception is not None: + return False, ImportError(machete_import_exception) + return cls._validate(**args) + + @classmethod + def validate_device(cls, device: DEVICE): + super().validate_device(device) + if device == DEVICE.CUDA: + if IS_ROCM: + raise NotImplementedError("Machete kernel is not supported on ROCm.") + if not _validate_machete_device_support(): + raise NotImplementedError("Machete kernel requires compute capability >= 9.0.") + + def post_init(self): + device = self.qweight.device + + # Reconstruct integer weights from packed AWQ representation + qweight_int = unpack_cols( + self.qweight, + self.bits, + self.in_features, + self.out_features, + ).to(device=device) + + packed = pack_quantized_values_into_int32( + qweight_int, + self.weight_type, + packed_dim=0, + ) + packed = packed.t().contiguous().t() + prepacked = machete_prepack_B( + packed, + a_type=self.scales.dtype, + b_type=self.weight_type, + group_scales_type=self.scales.dtype, + ) + replace_parameter( + self, + "qweight", + torch.nn.Parameter(prepacked.contiguous(), requires_grad=False), + ) + + # Ensure scales are contiguous and resident on the correct device. + replace_parameter( + self, + "scales", + torch.nn.Parameter(self.scales.contiguous(), requires_grad=False), + ) + + # Convert zero-points: unpack columns, then pre-apply scales as expected by machete_mm + effective_group_size = self.in_features if self.group_size == -1 else self.group_size + num_groups = self.in_features // effective_group_size + + qzeros_unpacked = unpack_cols( + self.qzeros, + self.bits, + num_groups, + self.out_features, + ).to(device=device) + + scales = self.scales + qzeros_fp = (-1.0 * scales.to(dtype=scales.dtype) * qzeros_unpacked.to(scales.dtype)).contiguous() + replace_parameter( + self, + "qzeros", + torch.nn.Parameter(qzeros_fp, requires_grad=False), + ) + + if self.bias is not None: + self.bias = self.bias.to(device=device) + + super().post_init() + + def forward(self, x: torch.Tensor): + if x.shape[0] == 0: + return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device) + + input_2d = x.reshape(-1, x.shape[-1]) + group_scales = self.scales.to(dtype=input_2d.dtype) + group_zeros = self.qzeros.to(dtype=input_2d.dtype) + + output = machete_mm( + a=input_2d, + b_q=self.qweight, + b_type=self.weight_type, + b_group_scales=group_scales, + b_group_zeros=group_zeros, + b_group_size=self.group_size, + ) + + if self.bias is not None: + output.add_(self.bias) + + result = output.reshape(x.shape[:-1] + (self.out_features,)) + + if self.adapter: + result = self.adapter.apply(x=x, out=result) + + return result + + +__all__ = ["AwqMacheteQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/machete.py b/gptqmodel/nn_modules/qlinear/machete.py new file mode 100644 index 000000000..177979e29 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/machete.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...nn_modules.qlinear import BaseQuantLinear +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from ...utils.machete import ( + _validate_machete_device_support, + check_machete_supports_shape, + machete_import_exception, + machete_mm, + machete_prepack_B, + pack_quantized_values_into_int32, + query_machete_supported_group_sizes, + unpack_quantized_values_into_int32, +) +from ...utils.marlin import replace_parameter +from ...utils.marlin_scalar_type import scalar_types +from ...utils.rocm import IS_ROCM + + +log = setup_logger() + + +class MacheteQuantLinear(BaseQuantLinear): + SUPPORTS_BITS = [4, 8] + SUPPORTS_GROUP_SIZE = [-1, 64, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [128] + + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + REQUIRES_FORMAT_V2 = False + + QUANT_TYPE = "machete" + + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + def __init__( + self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + register_buffers: bool = False, + adapter: Adapter = None, + **kwargs): + if machete_import_exception is not None: + raise ValueError( + "Trying to use the machete backend, but could not import the " + f"C++/CUDA dependencies with the following error: {machete_import_exception}" + ) + + if (bits, sym) not in self.TYPE_MAP: + raise ValueError(f"Unsupported quantization config: bits={bits}, sym={sym}") + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.MACHETE), + adapter=adapter, + register_buffers=False, + **kwargs) + + # Quantized weights (packed) + self.register_parameter( + "qweight", + torch.nn.Parameter( + torch.empty( + self.in_features // self.pack_factor, + self.out_features, + dtype=torch.int32, + ), + requires_grad=False, + ), + ) + + # Activation order indices + self.register_parameter( + "g_idx", + torch.nn.Parameter( + torch.empty(self.in_features, dtype=torch.int32), + requires_grad=False, + ), + ) + + # Scales + scales_rows = self.in_features if self.group_size == -1 else self.in_features // self.group_size + self.register_parameter( + "scales", + torch.nn.Parameter( + torch.empty( + scales_rows, + self.out_features, + dtype=torch.float16, + ), + requires_grad=False, + ), + ) + + # Zero points unused for symmetric GPTQ + self.register_parameter( + "qzeros", + torch.nn.Parameter( + torch.empty(0, dtype=torch.float16), + requires_grad=False, + ), + ) + + if bias: + self.register_buffer("bias", torch.zeros((self.out_features), dtype=torch.float16)) + else: + self.bias = None + + self.weight_type = self.TYPE_MAP[(self.bits, sym)] + self.has_zero_points = False + + # Buffer storing permutation applied to activations (empty when unused) + self.register_buffer("input_perm", torch.empty(0, dtype=torch.int32)) + + @classmethod + def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: + if machete_import_exception is not None: + return False, ImportError(machete_import_exception) + + ok, err = cls._validate(**args) + if not ok: + return ok, err + + in_features = args.get("in_features") + out_features = args.get("out_features") + if in_features is not None and out_features is not None: + supported, reason = check_machete_supports_shape(in_features, out_features) + if not supported: + return False, ValueError(reason) + + bits = args.get("bits") + sym = args.get("sym", True) + quant_type = cls.TYPE_MAP.get((bits, sym)) + if quant_type is None: + return False, ValueError(f"Machete does not support bits={bits}, sym={sym}") + + group_size = args.get("group_size") + dtype = args.get("dtype", torch.float16) + if group_size not in query_machete_supported_group_sizes(dtype): + return False, ValueError( + f"Machete does not support group_size={group_size} for dtype={dtype}" + ) + + return True, None + + @classmethod + def validate_device(cls, device: DEVICE): + super().validate_device(device) + if device == DEVICE.CUDA: + if IS_ROCM: + raise NotImplementedError("Machete kernel is not supported on ROCm.") + if not _validate_machete_device_support(): + raise NotImplementedError("Machete kernel requires compute capability >= 9.0.") + + def post_init(self): + device = self.qweight.device + + perm = None + if self.desc_act: + perm = torch.argsort(self.g_idx).to(torch.int32) + sorted_g_idx = self.g_idx[perm] + replace_parameter( + self, + "g_idx", + torch.nn.Parameter(sorted_g_idx.to(device=device), requires_grad=False), + ) + self.input_perm = perm.to(device=device) + else: + self.input_perm = torch.empty(0, dtype=torch.int32, device=device) + + qweight_unpacked = unpack_quantized_values_into_int32( + self.qweight.data, self.weight_type, packed_dim=0) + if perm is not None: + qweight_unpacked = qweight_unpacked[perm, :] + + qweight_packed = pack_quantized_values_into_int32( + qweight_unpacked, self.weight_type, packed_dim=0) + qweight_packed = qweight_packed.t().contiguous().t() + prepacked = machete_prepack_B( + qweight_packed, + a_type=self.scales.dtype, + b_type=self.weight_type, + group_scales_type=self.scales.dtype, + ) + replace_parameter( + self, + "qweight", + torch.nn.Parameter(prepacked.contiguous(), requires_grad=False), + ) + + replace_parameter( + self, + "scales", + torch.nn.Parameter(self.scales.data.contiguous(), requires_grad=False), + ) + + replace_parameter( + self, + "qzeros", + torch.nn.Parameter(torch.empty(0, dtype=self.scales.dtype, device=device), requires_grad=False), + ) + self.has_zero_points = False + + if self.bias is not None: + self.bias = self.bias.to(device=device) + + super().post_init() + + def list_buffers(self) -> List: + buf = super().list_buffers() + if hasattr(self, "input_perm") and self.input_perm is not None: + buf.append(self.input_perm) + return buf + + def forward(self, x: torch.Tensor): + if x.shape[0] == 0: + return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device) + + input_2d = x.reshape(-1, x.shape[-1]) + + if self.input_perm.numel() > 0: + perm = self.input_perm + if perm.device != input_2d.device: + perm = perm.to(device=input_2d.device) + input_2d = input_2d[:, perm] + + group_scales = self.scales + if group_scales.dtype != input_2d.dtype: + group_scales = group_scales.to(dtype=input_2d.dtype) + + group_zeros = self.qzeros if self.has_zero_points and self.qzeros.numel() > 0 else None + + output = machete_mm( + a=input_2d, + b_q=self.qweight, + b_type=self.weight_type, + b_group_scales=group_scales, + b_group_zeros=group_zeros, + b_group_size=self.group_size, + ) + + if self.bias is not None: + output.add_(self.bias) + + result = output.reshape(x.shape[:-1] + (self.out_features,)) + + if self.adapter: + result = self.adapter.apply(x=x, out=result) + + return result + + +__all__ = ["MacheteQuantLinear"] diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index 58af9135e..664e0e40c 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -17,6 +17,7 @@ class BACKEND(str, Enum): EXLLAMA_V1 = "exllama_v1" # FAST: optimized for batching == 1 EXLLAMA_V2 = "exllama_v2" # FASTER: optimized for batching > 1 EXLLAMA_EORA = "exllama_eora" + MACHETE = "machete" # CUTLASS-based kernel optimized for Hopper (SM90+) MARLIN = "marlin" # FASTEST: marlin reduce ops in fp32 (higher precision -> more accurate, slightly slower) MARLIN_FP16 = "marlin_fp16" # FASTEST and then some: marlin reduce ops in fp16 (lower precision -> less accurate, slightly faster) BITBLAS = "bitblas" # EXTREMELY FAST: speed at the cost of 10+ minutes of AOT (ahead of time compilation with disk cache) diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 2cf2db75d..3bb46f107 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -18,11 +18,13 @@ from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear +from ..nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear +from ..nn_modules.qlinear.machete import MacheteQuantLinear from ..nn_modules.qlinear.marlin import MarlinQuantLinear from ..nn_modules.qlinear.qqq import QQQQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear @@ -45,6 +47,7 @@ AUTO_SELECT_BACKEND_ORDER_MAP = { METHOD.GPTQ: OrderedDict({ + BACKEND.MACHETE: MacheteQuantLinear, # optimized for sm90+ BACKEND.MARLIN: MarlinQuantLinear, # optimized for bs > 1 # BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear, # BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, # optimized for bs > 1 @@ -59,6 +62,7 @@ BACKEND.QQQ: QQQQuantLinear, # qqq kernel based on marlin }), METHOD.AWQ: OrderedDict({ + BACKEND.MACHETE: AwqMacheteQuantLinear, BACKEND.MARLIN: AwqMarlinQuantLinear, BACKEND.EXLLAMA_V2: AwqExllamaV2QuantLinear, BACKEND.EXLLAMA_V1: AwqExllamaQuantLinear, @@ -70,7 +74,7 @@ SUPPORTS_BACKEND_MAP = { METHOD.GPTQ: { - FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH_FUSED, BACKEND.TORCH, BACKEND.MARLIN_FP16, BACKEND.EXLLAMA_EORA], + FORMAT.GPTQ: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH_FUSED, BACKEND.TORCH, BACKEND.MARLIN_FP16, BACKEND.EXLLAMA_EORA], FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH], FORMAT.MARLIN: [BACKEND.MARLIN, BACKEND.MARLIN_FP16], FORMAT.BITBLAS: [BACKEND.BITBLAS], @@ -79,10 +83,10 @@ FORMAT.QQQ: [BACKEND.QQQ], }, METHOD.AWQ: { - FORMAT.GEMM: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM], + FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM], FORMAT.GEMV: [BACKEND.GEMV], FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST], - FORMAT.MARLIN: [BACKEND.MARLIN], + FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN], } } @@ -280,6 +284,11 @@ def select_quant_linear( qlinear = TritonV2QuantLinear elif backend == BACKEND.BITBLAS: qlinear = BitBLASQuantLinear + elif backend == BACKEND.MACHETE: + if quant_method == METHOD.AWQ: + qlinear = AwqMacheteQuantLinear + else: + qlinear = MacheteQuantLinear elif backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16]: if quant_method == METHOD.AWQ: qlinear = AwqMarlinQuantLinear diff --git a/gptqmodel/utils/machete.py b/gptqmodel/utils/machete.py new file mode 100644 index 000000000..57aaee535 --- /dev/null +++ b/gptqmodel/utils/machete.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from typing import List, Optional + +import torch + +from ._extension_loader import load_extension_module +from .logger import setup_logger +from .marlin_scalar_type import ScalarType, scalar_types + + +log = setup_logger() + +machete_import_exception: Optional[str] = None +try: + gptqmodel_machete_kernels = load_extension_module("gptqmodel_machete_kernels") +except ImportError as e: # pragma: no cover - surfaced at runtime + machete_import_exception = str(e) + gptqmodel_machete_kernels = None + +MACHETE_PREPACKED_BLOCK_SHAPE = (64, 128) + + +def _validate_machete_device_support() -> bool: + return (torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 9) + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(_zero_points: bool) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + +def query_machete_supported_group_sizes(act_type: torch.dtype) -> List[int]: + if act_type in (torch.float16, torch.bfloat16): + return [-1, 64, 128] + return [-1, 128] + + +def check_machete_supports_shape(in_features: int, + out_features: int) -> tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return (False, + f"Input features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[0]}") + if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return (False, + f"Output features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[1]}") + return (True, None) + + +def _ensure_machete_loaded(): + if machete_import_exception is not None: + raise ImportError( + f"Trying to use the machete backend, but could not import the C++/CUDA dependencies: {machete_import_exception}" + ) + + +def _maybe_scalar_type(t: Optional[torch.Tensor]) -> Optional[torch.dtype]: + return t.dtype if t is not None else None + + +def machete_prepack_B( + weight: torch.Tensor, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + _ensure_machete_loaded() + return gptqmodel_machete_kernels.machete_prepack_B( + weight, + a_type, + b_type.id, + group_scales_type, + ) + + +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype] = None, + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> List[str]: + _ensure_machete_loaded() + return gptqmodel_machete_kernels.machete_supported_schedules( + a_type, + b_type.id, + group_scales_type, + group_zeros_type, + channel_scales_type, + token_scales_type, + out_type, + ) + + +def machete_mm( + *, + a: torch.Tensor, + b_q: torch.Tensor, + b_type: ScalarType, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + out_type: Optional[torch.dtype] = None, + schedule: Optional[str] = None) -> torch.Tensor: + _ensure_machete_loaded() + return gptqmodel_machete_kernels.machete_mm( + a, + b_q, + b_type.id, + out_type, + b_group_scales, + b_group_zeros, + b_group_size, + b_channel_scales, + a_token_scales, + schedule, + ) + + +def pack_quantized_values_into_int32( + tensor: torch.Tensor, + qtype: ScalarType, + packed_dim: int = 0) -> torch.Tensor: + perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + temp = tensor.permute(perm) + + pack_factor = 32 // qtype.size_bits + mask = (1 << qtype.size_bits) - 1 + + assert temp.shape[-1] % pack_factor == 0 + new_shape = list(temp.shape) + new_shape[-1] //= pack_factor + + result = torch.zeros(new_shape, dtype=torch.int32, device=tensor.device) + for i in range(pack_factor): + result |= ((temp[..., i::pack_factor] & mask) << (qtype.size_bits * i)) + + return result.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + tensor: torch.Tensor, + qtype: ScalarType, + packed_dim: int = 0) -> torch.Tensor: + perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + temp = tensor.permute(perm) + + pack_factor = 32 // qtype.size_bits + mask = (1 << qtype.size_bits) - 1 + + new_shape = list(temp.shape) + new_shape[-1] *= pack_factor + + result = torch.zeros(new_shape, dtype=torch.int32, device=tensor.device) + for i in range(pack_factor): + result[..., i::pack_factor] = (temp >> (qtype.size_bits * i)) & mask + + return result.permute(inv_perm) + + +__all__ = [ + "_validate_machete_device_support", + "check_machete_supports_shape", + "machete_import_exception", + "machete_mm", + "machete_prepack_B", + "machete_supported_schedules", + "pack_quantized_values_into_int32", + "query_machete_supported_act_types", + "query_machete_supported_group_sizes", + "query_machete_supported_quant_types", + "unpack_quantized_values_into_int32", +] diff --git a/gptqmodel_ext/cutlass_extensions/__init__.py b/gptqmodel_ext/cutlass_extensions/__init__.py new file mode 100644 index 000000000..a903f4cee --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/__init__.py @@ -0,0 +1 @@ +# Cutlass extension helpers for GPTQModel diff --git a/gptqmodel_ext/cutlass_extensions/common.cpp b/gptqmodel_ext/cutlass_extensions/common.cpp new file mode 100644 index 000000000..3d2093ab9 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/gptqmodel_ext/cutlass_extensions/common.hpp b/gptqmodel_ext/cutlass_extensions/common.hpp new file mode 100644 index 000000000..f2c1dcf69 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/common.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); + +/** + * A wrapper for a kernel that is used to guard against compilation on + * architectures that will never use the kernel. The purpose of this is to + * reduce the size of the compiled binary. + * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef + * into code that will be executed on the device where it is defined. + */ +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm90_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm100_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm120_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/gptqmodel_ext/cutlass_extensions/cute_utils.cuh b/gptqmodel_ext/cutlass_extensions/cute_utils.cuh new file mode 100644 index 000000000..f61fe3ceb --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) { + return true; + } else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + +//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp new file mode 100644 index 000000000..5c1d6e3f4 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -0,0 +1,457 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcastArray { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + const Element* const* ptr_row_array = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, + int group, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , group(group) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + int group; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row_array[group])); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + l, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcastArray { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. + struct Arguments { + const Element* const* ptr_col_array = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + int group, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + group(group), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + int group; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col_array[group])); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + l, + params + ); + } +}; + +} diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp new file mode 100644 index 000000000..7aa87feb4 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -0,0 +1,497 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either +// row/column or scalar broadcasting where the tensor being loaded from is +// always passed in via a device pointer. This lets one compiled kernel handle +// all cases of per-tensor or per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->row_broadcast) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->col_broadcast) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = *(params_ptr->ptr_col); + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 000000000..58b1e8ff1 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row)); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col)); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + params + ); + } +}; + +} diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 000000000..ad8c0067d --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,321 @@ +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(std::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::homogeneous_multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::homogeneous_multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::homogeneous_multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +}; // namespace vllm::c2x diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 000000000..c43eea0a0 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,450 @@ +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c3x { + +using namespace cute; + +template +struct identity { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { return lhs; } +}; + +template +struct TrivialEpilogue { + private: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using Compute = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + template + static ArgumentType prepare_args(Args... args) { + return {}; + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + template + using ColOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(std::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } + + template + static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) { + using Arguments = typename Descriptor::Arguments; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr, do_broadcast}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::homogeneous_multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogueBias, but the + * bias is a column vector instead of a row vector. Useful e.g. if we are + * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. + */ +template +struct ScaledEpilogueColumnBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template ColLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::homogeneous_multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::homogeneous_multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::homogeneous_multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers + to arrays containing different scales used in group gemm. The number of + pointers in ScaleA and the number of pointers in ScaleB are equal to the + group size. +*/ +template +struct ScaledEpilogueArray + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoadArray; + using ScaleB = typename SUPER::template RowOrScalarLoadArray; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; + using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; + + static ArgumentType prepare_args(float const* const* a_scales_ptr, + float const* const* b_scales_ptr, + bool a_col_broadcast, bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor( + a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor( + b_scales_ptr, b_row_broadcast); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +}; // namespace vllm::c3x diff --git a/gptqmodel_ext/cutlass_extensions/torch_utils.hpp b/gptqmodel_ext/cutlass_extensions/torch_utils.hpp new file mode 100644 index 000000000..a1ff933cc --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,160 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. +template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + std::optional const& tensor, + std::string_view name = "tensor") { + using Layout = decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/gptqmodel_ext/cutlass_extensions/vllm_collective_builder.cuh b/gptqmodel_ext/cutlass_extensions/vllm_collective_builder.cuh new file mode 100644 index 000000000..085ee1290 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/vllm_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for +// for custom kernel tags, allowing you to build custom collectives. Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct VLLMCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct VLLMCollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/gptqmodel_ext/cutlass_extensions/vllm_custom_types.cuh b/gptqmodel_ext/cutlass_extensions/vllm_custom_types.cuh new file mode 100644 index 000000000..6146bdc1f --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/vllm_custom_types.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct vllm_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + vllm_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value) + : Base(value) {} +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// "GPTQ" types, i.e. symmetric quantization +using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8 +using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sizeof_bits> { + static constexpr int value = Bits; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/gptqmodel_ext/cutlass_extensions/vllm_cutlass_library_extension.py b/gptqmodel_ext/cutlass_extensions/vllm_cutlass_library_extension.py new file mode 100644 index 000000000..34fb64c41 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/vllm_cutlass_library_extension.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum + +from cutlass_library import * + +# +# Extend cutlass library with custom types, and missing values +# + + +class VLLMDataType(enum.Enum): + u4b8 = enum_auto() + u8b128 = enum_auto() + + +class MixedInputKernelScheduleType(enum.Enum): + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() + + +VLLMDataTypeNames: dict[VLLMDataType | DataType, str] = { + **DataTypeNames, # type: ignore + **{ + VLLMDataType.u4b8: "u4b8", + VLLMDataType.u8b128: "u8b128", + }, +} + +VLLMDataTypeTag: dict[VLLMDataType | DataType, str] = { + **DataTypeTag, # type: ignore + **{ + VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", + VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", + }, +} + +VLLMDataTypeSize: dict[VLLMDataType | DataType, int] = { + **DataTypeSize, # type: ignore + **{ + VLLMDataType.u4b8: 4, + VLLMDataType.u8b128: 8, + }, +} + +VLLMDataTypeVLLMScalarTypeTag: dict[VLLMDataType | DataType, str] = { + VLLMDataType.u4b8: "vllm::kU4B8", + VLLMDataType.u8b128: "vllm::kU8B128", + DataType.u4: "vllm::kU4", + DataType.u8: "vllm::kU8", + DataType.s4: "vllm::kS4", + DataType.s8: "vllm::kS8", + DataType.f16: "vllm::kFloat16", + DataType.bf16: "vllm::kBfloat16", +} + +VLLMDataTypeTorchDataTypeTag: dict[VLLMDataType | DataType, str] = { + DataType.u8: "at::ScalarType::Byte", + DataType.s8: "at::ScalarType::Char", + DataType.e4m3: "at::ScalarType::Float8_e4m3fn", + DataType.s32: "at::ScalarType::Int", + DataType.f16: "at::ScalarType::Half", + DataType.bf16: "at::ScalarType::BFloat16", + DataType.f32: "at::ScalarType::Float", +} + +VLLMKernelScheduleTag: dict[MixedInputKernelScheduleType | KernelScheduleType, str] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501 + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", # noqa: E501 + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", # noqa: E501 + }, +} diff --git a/gptqmodel_ext/cutlass_extensions/vllm_numeric_conversion.cuh b/gptqmodel_ext/cutlass_extensions/vllm_numeric_conversion.cuh new file mode 100644 index 000000000..90f226cf6 --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/vllm_numeric_conversion.cuh @@ -0,0 +1,992 @@ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/vllm_custom_types.cuh" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_type_utils.cuh" + +// this file extends: +// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h +// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t +// as well as adds interleaved numeric array converters for specific types. +// (interleaved numeric array converters can be more efficient for subbyte +// types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. +template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + if (cute::elect_one_sync()) { + if constexpr (std::is_same_v) { + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); + } else { + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); + } + __brkpt(); + } + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { + if constexpr (sizeof(PackedSrc) == 1) { + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 2) { + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 4) { + return Array{reinterpret_cast(src)}; + } else { + static_assert(sizeof(PackedSrc) == 8); + return reinterpret_cast const&>(src); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_regs(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed +// into 2 32bit register. +template +CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( + uint32_t src) { + cutlass::AlignedArray r; + // Determines if the value is in the top half of the LUT if set or + // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move + // into bit position 0x4 of each nibble so when or'd with final_prmt_base it + // selects the correct candidate. When elements in final_prmt_base + // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements + // are < 0x4, the low candidate is selected (i.e. LUT[0:7]) + uint32_t high_bit = (src & 0x88888888) >> 1; + + // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT + // (selects correct high or low candidate) + const uint32_t final_prmt_base = 0x32103210; + + // Ignore the high bit when indexing into LUT, for each 4bit value + // we index into both the high and low candidates then use + // high_bit | final_prmt_base to select the correct candidate + uint32_t lut_idx = (src & 0x77777777); + + auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | + (uint32_t(d) << 24); + }; + + static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); + static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); + static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); + static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { + uint32_t final_prmt_idx = final_prmt_base | high_bit; + + // This uses a look up table to convert packed int4s to packed int8s, + // using the int4 value as the index to prmt. It first select both the + // high and low candidates, then uses the high bit (i.e. `high_bit`) to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 low, high;\n" + " prmt.b32 low, %1, %2, %5;\n" + " prmt.b32 high, %3, %4, %5;\n" + " prmt.b32 %0, low, high, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), + "r"(final_prmt_idx)); + } + + return r; +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s + auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // + 0xFC, 0xFD, 0xFE, 0xFF, // + 0x00, 0x01, 0x02, 0x03, // + 0x04, 0x05, 0x06, 0x07>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s + auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // + 0xC8, 0xC4, 0xC0, 0xB8, // + 0x00, 0x38, 0x40, 0x44, // + 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. + static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src_reg = src_[0]; + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +// for Array <= Array +// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 + template + CUTLASS_DEVICE static PackedResultType convert( + Array src) { + // Hold output int8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray< + uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; + RegArray r; + + static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; + auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); + + *reinterpret_cast(&src[0]) = + __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); + + if constexpr (src_regs > 1) { + *reinterpret_cast(&src[1]) = + __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); + } + + static_assert(PackedResultType::kElements <= 4); + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), + "n"(MASK_0246)); + + uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); + + return reinterpret_cast(int8s); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/gptqmodel_ext/cutlass_extensions/vllm_type_utils.cuh b/gptqmodel_ext/cutlass_extensions/vllm_type_utils.cuh new file mode 100644 index 000000000..500ed508c --- /dev/null +++ b/gptqmodel_ext/cutlass_extensions/vllm_type_utils.cuh @@ -0,0 +1,42 @@ +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" +#include "cuda_bf16.h" + +#include "cutlass_extensions/vllm_custom_types.cuh" + +namespace cutlass { + +template +struct nameof { + static constexpr char const* value = "unknown"; +}; + +template +inline constexpr auto nameof_v = nameof::value; + +#define NAMEOF_TYPE(T) \ + template <> \ + struct nameof { \ + static constexpr char const* value = #T; \ + }; + +NAMEOF_TYPE(float_e4m3_t) +NAMEOF_TYPE(float_e5m2_t) +NAMEOF_TYPE(half_t) +NAMEOF_TYPE(nv_bfloat16) +NAMEOF_TYPE(bfloat16_t) +NAMEOF_TYPE(float) + +NAMEOF_TYPE(int4b_t) +NAMEOF_TYPE(int8_t) +NAMEOF_TYPE(int32_t) +NAMEOF_TYPE(int64_t) + +NAMEOF_TYPE(vllm_uint4b8_t) +NAMEOF_TYPE(uint4b_t) +NAMEOF_TYPE(uint8_t) +NAMEOF_TYPE(vllm_uint8b128_t) +NAMEOF_TYPE(uint32_t) +NAMEOF_TYPE(uint64_t) + +}; // namespace cutlass \ No newline at end of file diff --git a/gptqmodel_ext/machete/Readme.md b/gptqmodel_ext/machete/Readme.md new file mode 100644 index 000000000..6ffb2416b --- /dev/null +++ b/gptqmodel_ext/machete/Readme.md @@ -0,0 +1,45 @@ +# Machete (Mixed Precision Cutlass-Based GEMM) + +Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin. + +## Overview + +Machete effectively performs + +```python +scale_type = w_s.dtype +compute_type = a.dtype +out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a +``` + +Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and +`w_z` is the quantization zeropoints. + +> **_NOTE:_** `w_z` is added after the scales so we can +use FMA operations, but this means they must have the scales pre-applied if the +supplied zeropoints assume that they will be subtracted before the scales are +applied. + +## API + +The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like: + +```python +from vllm import _custom_ops as ops + +... +W_q_packed = ops.machete_prepack_B(w_q, wtype) +output = ops.machete_gemm( + a, + b_q=W_q_packed, + b_type=wtype, + b_scales=w_s, + b_group_size=group_size +) +``` + +## Code Generation + +Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`. + +New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate. diff --git a/gptqmodel_ext/machete/core/registration.h b/gptqmodel_ext/machete/core/registration.h new file mode 100644 index 000000000..c83d6cebf --- /dev/null +++ b/gptqmodel_ext/machete/core/registration.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } \ No newline at end of file diff --git a/gptqmodel_ext/machete/core/scalar_type.hpp b/gptqmodel_ext/machete/core/scalar_type.hpp new file mode 100644 index 000000000..97078169d --- /dev/null +++ b/gptqmodel_ext/machete/core/scalar_type.hpp @@ -0,0 +1,362 @@ +#pragma once + +// For TORCH_CHECK +#include + +#include + +namespace vllm { + +template +inline To bit_cast_like(const From& src) noexcept { + static_assert(sizeof(To) == sizeof(From), + "bit_cast_like requires source and destination to be the same size"); + To dst{}; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, + nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + return bit_cast_like(double_raw); + } + + std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = bit_cast_like(max); + uint64_t min_raw = max_raw | sign_bit_double; + return {bit_cast_like(min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. + // (accounting for bias if there is one) + std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = + ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +}; // namespace vllm diff --git a/gptqmodel_ext/machete/generate.py b/gptqmodel_ext/machete/generate.py new file mode 100644 index 000000000..bea0bddda --- /dev/null +++ b/gptqmodel_ext/machete/generate.py @@ -0,0 +1,718 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools +import math +import os +import shutil +from collections.abc import Iterable +from copy import deepcopy +from dataclasses import dataclass, fields +from functools import reduce +from pathlib import Path + +import jinja2 + +import sys + +_ROOT = Path(__file__).resolve().parents[2] +_CUTLASS_EXT_DIR = _ROOT / "gptqmodel_ext" / "cutlass_extensions" + +_CUTLASS_ROOT = os.environ.get("GPTQMODEL_CUTLASS_DIR") +if _CUTLASS_ROOT is not None: + _CUTLASS_ROOT = Path(_CUTLASS_ROOT) +else: + _CUTLASS_ROOT = _ROOT / "cutlass" + +_CUTLASS_PYTHON_DIR = _CUTLASS_ROOT / "python" + +if str(_CUTLASS_EXT_DIR) not in sys.path: + sys.path.append(str(_CUTLASS_EXT_DIR)) +if _CUTLASS_PYTHON_DIR.exists() and str(_CUTLASS_PYTHON_DIR) not in sys.path: + sys.path.append(str(_CUTLASS_PYTHON_DIR)) +if not _CUTLASS_PYTHON_DIR.exists(): + raise RuntimeError( + "CUTLASS python bindings not found. Set GPTQMODEL_CUTLASS_DIR to a valid CUTLASS checkout." + ) + +from vllm_cutlass_library_extension import ( + DataType, + EpilogueScheduleTag, + EpilogueScheduleType, + MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType, + VLLMDataType, + VLLMDataTypeNames, + VLLMDataTypeSize, + VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, + VLLMKernelScheduleTag, +) + +# +# Generator templating +# + +DISPATCH_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { + +{% for impl_config in impl_configs %} +{% set type_sig = gen_type_sig(impl_config.types) -%} +{% for s in impl_config.schedules %} +extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs); +{%- endfor %} + +torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.maybe_schedule) { + {%- for cond, s in impl_config.heuristic %} + {%if cond is not none%}if ({{cond}}) + {%- else %}else + {%- endif %} + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %} + } + + {%- for s in impl_config.schedules %} + if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}") + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args); + {%- endfor %} + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " + "schedule = ", *args.maybe_schedule); +} +{%- endfor %} + + +static inline std::optional maybe_scalartype( + std::optional const& t) { + if (!t) { + return std::nullopt; + } else { + return t->scalar_type(); + }; +} + +torch::Tensor mm_dispatch(MMArgs args) { + auto out_type = args.maybe_out_type.value_or(args.A.scalar_type()); + auto a_type = args.A.scalar_type(); + auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales); + auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros); + auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales); + auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set type_sig = gen_type_sig(t) -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!maybe_g_scales_type{%endif%} + && {%if t.b_group_zeropoint != void -%} + maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!maybe_g_zeros_type{%endif%} + && {%if t.b_channel_scale != void -%} + maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}} + {%- else %}!maybe_ch_scales_type{%endif%} + && {%if t.a_token_scale != void -%} + maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}} + {%- else %}!maybe_tok_scales_type{%endif%} + ) { + return mm_dispatch_{{type_sig}}(args); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "machete_mm(..) is not implemented for " + "a_type=", args.A.scalar_type(), + ", b_type=", args.b_type.str(), + ", out_type=", out_type, + ", with_group_scale_type=", maybe_g_scales_type + ? toString(*maybe_g_scales_type) : "None", + ", with_group_zeropoint_type=", maybe_g_zeros_type + ? toString(*maybe_g_zeros_type) : "None", + ", with_channel_scale_type=", maybe_ch_scales_type + ? toString(*maybe_ch_scales_type) : "None", + ", with_token_scale_type=", maybe_tok_scales_type + ? toString(*maybe_tok_scales_type) : "None", + "; implemented types are: \\n", + {%- for impl_config in impl_configs %} + {% set t = impl_config.types -%} + "\\t{{gen_type_option_name(t)}}\\n", + {%- endfor %} + ""); +} + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args) { + auto out_type = args.maybe_out_type.value_or(args.a_type); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set schs = impl_config.schedules -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && args.a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!args.maybe_group_scales_type{%endif%} + && {%if t.b_group_zeropoint != void-%} + args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!args.maybe_group_zeros_type{%endif%} + ) { + return { + {%- for s in impl_config.schedules %} + "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %} + {%- endfor %} + }; + } + {%- endfor %} + + return {}; +}; + +}; // namespace machete +""" + +IMPL_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { + +{% for sch in unique_schedules(impl_configs) %} +{% set sch_sig = gen_sch_sig(sch) -%} +struct sch_{{sch_sig}} { + using TileShapeNM = Shape<{{ + to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; + using ClusterShape = Shape<{{ + to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>; + // TODO: Reimplement + // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; + using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}}; + using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; +{% endfor %} + +{% for impl_config in impl_configs %} +{% set t = impl_config.types -%} +{% set schs = impl_config.schedules -%} +{% set type_sig = gen_type_sig(t) -%} + +template +using Kernel_{{type_sig}} = MacheteKernelTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[t.b]}}, // ElementB + {{DataTypeTag[t.out]}}, // ElementD + {{DataTypeTag[t.accumulator]}}, // Accumulator + {{DataTypeTag[t.b_group_scale]}}, // GroupScaleT + {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT + {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT + {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Sch>; + +{% for sch in schs %} +{% set sch_sig = gen_sch_sig(sch) -%} +torch::Tensor +impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) { + return run_impl>(args); +} +{%- endfor %} +{%- endfor %} + +}; // namespace machete +""" + +PREPACK_TEMPLATE = """ +#include "../machete_prepack_launcher.cuh" + +namespace machete { + +torch::Tensor prepack_B_dispatch(PrepackBArgs args) { + auto convert_type = args.maybe_group_scales_type.value_or(args.a_type); + {%- for t in types %} + {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %} + if (args.a_type == {{TorchTypeTag[t.a]}} + && args.b_type.size_bits() == {{t.b_num_bits}} + && convert_type == {{TorchTypeTag[t.convert]}}) { + return prepack_impl< + PrepackedLayoutBTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[b_type]}}, // ElementB + {{DataTypeTag[t.convert]}}, // ElementConvert + {{DataTypeTag[t.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperative> + >(args.B); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED(false, + "prepack_B_dispatch(..) is not implemented for " + "atype = ", args.a_type, + ", b_type = ", args.b_type.str(), + ", with_group_scales_type= ", args.maybe_group_scales_type ? + toString(*args.maybe_group_scales_type) : "None"); +} + +}; // namespace machete +""" + +TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative +TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative + + +@dataclass(frozen=True) +class ScheduleConfig: + tile_shape_mn: tuple[int, int] + cluster_shape_mnk: tuple[int, int, int] + kernel_schedule: MixedInputKernelScheduleType + epilogue_schedule: EpilogueScheduleType + tile_scheduler: TileSchedulerType + + +@dataclass(frozen=True) +class TypeConfig: + a: DataType + b: DataType | VLLMDataType + b_group_scale: DataType + b_group_zeropoint: DataType + b_channel_scale: DataType + a_token_scale: DataType + out: DataType + accumulator: DataType + + +@dataclass(frozen=True) +class PrepackTypeConfig: + a: DataType + b_num_bits: int + convert: DataType + accumulator: DataType + + +@dataclass +class ImplConfig: + types: TypeConfig + schedules: list[ScheduleConfig] + heuristic: list[tuple[str | None, ScheduleConfig]] + + +def generate_sch_sig(schedule_config: ScheduleConfig) -> str: + tile_shape = ( + f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" + ) + cluster_shape = ( + f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}" + ) + kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split( + "::" + )[-1] + epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split( + "::" + )[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1] + + return ( + f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + f"_{epilogue_schedule}_{tile_scheduler}" + ) + + +# mostly unique shorter sch_sig +def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: + kernel_terse_names_replace = { + "KernelTmaWarpSpecializedCooperative": "TmaMI_", + "TmaWarpSpecializedCooperative_": "TmaCoop_", + "StreamKScheduler": "streamK", + } + + sch_sig = generate_sch_sig(schedule_config) + for orig, terse in kernel_terse_names_replace.items(): + sch_sig = sch_sig.replace(orig, terse) + return sch_sig + + +# unique type_name +def generate_type_signature(kernel_types: TypeConfig): + return str( + "".join( + [ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) + ) + + +def generate_type_option_name(kernel_types: TypeConfig): + return ", ".join( + [ + f"{field.name.replace('b_', 'with_') + '_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) + + +def is_power_of_two(n): + return (n != 0) and (n & (n - 1) == 0) + + +def to_cute_constant(value: list[int]): + def _to_cute_constant(value: int): + if is_power_of_two(value): + return f"_{value}" + else: + return f"Int<{value}>" + + if isinstance(value, Iterable): + return [_to_cute_constant(value) for value in value] + else: + return _to_cute_constant(value) + + +def unique_schedules(impl_configs: list[ImplConfig]): + # Use dict over set for deterministic ordering + return list( + { + sch: None for impl_config in impl_configs for sch in impl_config.schedules + }.keys() + ) + + +def unsigned_type_with_bitwidth(num_bits): + return { + 4: DataType.u4, + 8: DataType.u8, + 16: DataType.u16, + 32: DataType.u32, + 64: DataType.u64, + }[num_bits] + + +template_globals = { + "void": DataType.void, + "DataTypeTag": VLLMDataTypeTag, + "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag, + "TorchTypeTag": VLLMDataTypeTorchDataTypeTag, + "KernelScheduleTag": VLLMKernelScheduleTag, + "EpilogueScheduleTag": EpilogueScheduleTag, + "TileSchedulerTag": TileSchedulerTag, + "to_cute_constant": to_cute_constant, + "gen_sch_sig": generate_terse_sch_sig, + "gen_type_sig": generate_type_signature, + "unique_schedules": unique_schedules, + "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, + "gen_type_option_name": generate_type_option_name, +} + + +def create_template(template_str): + template = jinja2.Template(template_str) + template.globals.update(template_globals) + return template + + +mm_dispatch_template = create_template(DISPATCH_TEMPLATE) +mm_impl_template = create_template(IMPL_TEMPLATE) +prepack_dispatch_template = create_template(PREPACK_TEMPLATE) + + +def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): + sources = [] + + sources.append( + ( + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), + ) + ) + + prepack_types = [] + for impl_config in impl_configs: + convert_type = ( + impl_config.types.a + if impl_config.types.b_group_scale == DataType.void + else impl_config.types.b_group_scale + ) + prepack_types.append( + PrepackTypeConfig( + a=impl_config.types.a, + b_num_bits=VLLMDataTypeSize[impl_config.types.b], + convert=convert_type, + accumulator=impl_config.types.accumulator, + ) + ) + + def prepacked_type_key(prepack_type: PrepackTypeConfig): + # For now, we can just use the first accumulator type seen since + # the tensor core shapes/layouts don't vary based on accumulator + # type so we can generate less code this way + return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) + + unique_prepack_types = [] + prepack_types_seen = set() + for prepack_type in prepack_types: + key = prepacked_type_key(prepack_type) + if key not in prepack_types_seen: + unique_prepack_types.append(prepack_type) + prepack_types_seen.add(key) + + sources.append( + ( + "machete_prepack", + prepack_dispatch_template.render( + types=unique_prepack_types, + ), + ) + ) + + # Split up impls across files + num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) + num_impls_per_file = math.ceil(num_impls / num_impl_files) + + files_impls: list[list[ImplConfig]] = [[]] + + curr_num_impls_assigned = 0 + curr_impl_in_file = 0 + curr_impl_configs = deepcopy(list(reversed(impl_configs))) + + while curr_num_impls_assigned < num_impls: + room_left_in_file = num_impls_per_file - curr_impl_in_file + if room_left_in_file == 0: + files_impls.append([]) + room_left_in_file = num_impls_per_file + curr_impl_in_file = 0 + + curr_ic = curr_impl_configs[-1] + if len(curr_ic.schedules) >= room_left_in_file: + # Break apart the current impl config + tmp_ic = deepcopy(curr_ic) + tmp_ic.schedules = curr_ic.schedules[:room_left_in_file] + curr_ic.schedules = curr_ic.schedules[room_left_in_file:] + files_impls[-1].append(tmp_ic) + else: + files_impls[-1].append(curr_ic) + curr_impl_configs.pop() + curr_num_impls_assigned += len(files_impls[-1][-1].schedules) + curr_impl_in_file += len(files_impls[-1][-1].schedules) + + for part, file_impls in enumerate(files_impls): + sources.append( + ( + f"machete_mm_impl_part{part + 1}", + mm_impl_template.render(impl_configs=file_impls), + ) + ) + + return sources + + +def generate(): + # See csrc/quantization/machete/Readme.md, the Codegeneration for more info + # about how this works + SCRIPT_DIR = os.path.dirname(__file__) + + sch_common_params = dict( + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + ) + + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + default_tile_heuristic_config = { + #### M = 257+ + "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + "M > 256": ((128, 256), (2, 1, 1)), + #### M = 129-256 + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + "M > 128": ((128, 256), (2, 1, 1)), + #### M = 65-128 + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), + #### M = 33-64 + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), + #### M = 17-32 + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), + #### M = 1-16 + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + default_heuristic = [ + (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore + for cond, tile_config in default_tile_heuristic_config.items() + ] + + def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): + # Do not use schedules = list(set(...)) because we need to make sure + # the output list is deterministic; otherwise the generated kernel file + # will be non-deterministic and causes ccache miss. + schedules = [] + for _, schedule_config in heuristic: + if schedule_config not in schedules: + schedules.append(schedule_config) + return schedules + + impl_configs = [] + + GPTQ_kernel_type_configs = list( + TypeConfig( + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, + accumulator=DataType.f32, + ) + for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16) + ) + + impl_configs += [ + ImplConfig(x[0], x[1], x[2]) + for x in zip( + GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) + ] + + AWQ_kernel_type_configs = list( + TypeConfig( + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=a, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, + accumulator=DataType.f32, + ) + for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16) + ) + + impl_configs += [ + ImplConfig(x[0], x[1], x[2]) + for x in zip( + AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) + ] + + # TODO: Support W4A8 when ready + # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # # TODO (LucasWilkinson): Further tuning required + # qqq_tile_heuristic_config = { + # #### M = 257+ + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # # "M > 256": ((128, 256), (2, 1, 1)), + # "M > 256": ((128, 128), (2, 1, 1)), + # #### M = 129-256 + # "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + # "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 128": ((128, 256), (2, 1, 1)), + # "M > 128": ((128, 128), (2, 1, 1)), + # #### M = 65-128 + # "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + # "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + # "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + # "M > 64": ((128, 128), (2, 1, 1)), + # #### M = 33-64 + # "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # # Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + # "M > 32": ((128, 64), (2, 1, 1)), + # #### M = 17-32 + # "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + # "M > 16": ((256, 32), (2, 1, 1)), + # #### M = 1-16 + # "N >= 26624": ((256, 16), (1, 1, 1)), + # None: ((128, 16), (1, 1, 1)), + # } + + # # For now we use the same heuristic for all types + # # Heuristic is currently tuned for H100s + # qqq_heuristic = [ + # (cond, ScheduleConfig(*tile_config, + # **sch_common_params)) # type: ignore + # for cond, tile_config in qqq_tile_heuristic_config.items() + # ] + + # QQQ_kernel_types = [ + # *(TypeConfig( + # a=DataType.s8, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.s32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # *(TypeConfig( + # a=DataType.e4m3, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.f32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # ] + + # impl_configs += [ + # ImplConfig(x[0], x[1], x[2]) + # for x in zip(QQQ_kernel_types, + # itertools.repeat(get_unique_schedules(qqq_heuristic)), + # itertools.repeat(qqq_heuristic)) + # ] + + output_dir = os.path.join(SCRIPT_DIR, "generated") + + # Delete the "generated" directory if it exists + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + # Create the "generated" directory + os.makedirs(output_dir) + + # Render each group of configurations into separate files + for filename, code in create_sources(impl_configs): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") + + +if __name__ == "__main__": + generate() diff --git a/gptqmodel_ext/machete/machete_collective_builder.cuh b/gptqmodel_ext/machete/machete_collective_builder.cuh new file mode 100644 index 000000000..ee825583d --- /dev/null +++ b/gptqmodel_ext/machete/machete_collective_builder.cuh @@ -0,0 +1,31 @@ +#pragma once + +#include "cutlass_extensions/vllm_collective_builder.cuh" +#include "machete_mainloop.cuh" + +namespace cutlass::gemm::collective { +using namespace cute; + +struct MacheteKernelTag {}; + +template +struct VLLMCollectiveBuilder< + MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_, + GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB, + ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, + KernelScheduleType, + cute::enable_if_t<( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v)>> { + using CollectiveOp = machete::MacheteCollectiveMma< + ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, + AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, + StageCountType, KernelScheduleType>; +}; + +}; // namespace cutlass::gemm::collective diff --git a/gptqmodel_ext/machete/machete_interleaving_utils.cuh b/gptqmodel_ext/machete/machete_interleaving_utils.cuh new file mode 100644 index 000000000..d397f87f1 --- /dev/null +++ b/gptqmodel_ext/machete/machete_interleaving_utils.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace machete { + +using namespace cute; + +// get an interleaved block layout where each element consecutive element has a +// stride of bit_stride and the block width is blk_bit_width, +// examples: +// size_bits = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1 +// size_bits = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1) +// size_bits = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1) +// size_bits = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1) +template +CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() { + static_assert(blk_bit_width % bit_stride == 0); + static_assert(bit_stride % cute::sizeof_bits_v == 0); + + constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v; + + if constexpr (cute::sizeof_bits_v == bit_stride) { + // identity layout + return Layout>>{}; + } else { + constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v; + constexpr auto num_strides = elems_per_blk / elems_per_stride; + return Layout, Int>, + Stride, Int<1>>>{}; + } +} + +}; // namespace machete diff --git a/gptqmodel_ext/machete/machete_mainloop.cuh b/gptqmodel_ext/machete/machete_mainloop.cuh new file mode 100644 index 000000000..2f52a6b7a --- /dev/null +++ b/gptqmodel_ext/machete/machete_mainloop.cuh @@ -0,0 +1,1473 @@ +// +// Based off of: +// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Specifically: +// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Referred to as upstream from in the comments +// +// The main optimization machete implements compared to upstream is to prepack +// the weight matrix to more closely match the shape of the wgmma instructions +// allowing for wider (ideally 128bit) shared memory loads. For subbyte types +// this is done by packing values from multiple wgmma loads (for a single +// thread) into a single 128bit load. This is very similar to layout used in +// Marlin, although specific to the wgmma instructions. +// +// Since the wgmma instructions only support sourcing from registers for the A +// operand, and we want to upconvert/decompress the weight values/elements +// before feeding them into the tensor cores in registers, we need the weight +// matrix to be A. To achieve this we compute the transpose of Y = XW^t as +// Y^t = W^tX^t. This is mostly done outside of this file in +// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the +// quantized/narrow type and has the prepacked layout despite the API being: +// B_prepacked = machete_prepack_B(B) +// Y = machete_mm(A, B_prepacked) +// +#pragma once + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +#include "cutlass/detail/collective.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" + +namespace machete { + +using namespace cute; +using namespace cutlass; +using namespace cutlass::gemm; +using namespace cutlass::gemm::collective; +using namespace cutlass::gemm::collective::detail; + +template +struct MacheteCollectiveMma { + using Schedule = KernelScheduleType; + static_assert( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); + + public: + static constexpr bool ALayoutIsPrepacked = true; + + // Prepacked block shape (N is M in the transposed problem) + using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK; + // Prepacked blocks per dim for a single MMA tile + using PPBlocksPerTile_MK = decltype(make_shape( + size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}), + size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{}))); + + using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout; + + static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0, + "M in PPBlockShape_MK must evenly divide M TileShape_MNK"); + static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0, + "K in PPBlockShape_MK must evenly divide K TileShape_MNK"); + + using ArchTag = arch::Sm90; + using TileShape = TileShape_MNK; + using ClusterShape = ClusterShape_MNK; + using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>; + using StrideA = TagToStrideA_t; + using ElementB = ElementB_; + using StrideB = TagToStrideB_t; + using ElementAccumulator = ElementAccumulator_; + using ElementMma = ElementB; + using ElementATuple = + cute::conditional_t::value, + cute::tuple, ElementATuple_>; + + static constexpr cute::GMMA::Major GmmaMajorA = + gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + private: + // + // the setup section (until "section setup end") contains a combination of + // modified code from (used as a starting point): + // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl` + // `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp` + // (upstream) + // + // however in-order to simplify the code we combine a lot of the logic from + // `CollectiveMma` and `CollectiveBuilder` into this class, this also makes + // sense given that we have flexibility on layouts here. We also simplify the + // code by only supporting scales and zeros for A (in the transposed problem, + // B from an API perspective), also since we force A to be the narrow type + // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in + // the upstream also simplifying the code. This section includes new logic + // (compared ustream) for handling the prepacked-A layouts (in the transposed + // problem, B from an API perspective) + // + using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>; + using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>; + + static constexpr bool IsANarrow = cutlass::sizeof_bits::value < + cutlass::sizeof_bits::value; + static_assert(IsANarrow, + "A must be the narrow one since its the one that flows through " + "registers."); + + public: + static constexpr int PipelineStages = + compute_stage_count_or_override_single_affine_transformed_input< + sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale, + ElementZero, TileShape_MNK>(StageCountType{}); + + struct DispatchPolicy { + constexpr static int Stages = PipelineStages; + using ClusterShape = ClusterShape_MNK; + using Schedule = KernelScheduleType; + }; + + using GmemTiledCopyA = + decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = + decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + // ((T, V), (BlocksM, BlocksK), pipe) -> offset + using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + + using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + + using SmemLayoutAtomARowMajor = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomScale = Layout< + Shape(SmemLayoutAtomARowMajor{})), cute::Int<1>>>; + + using SmemLayoutAtomB = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomB = void; + + // + // Validity checks + // + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + private: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + public: + // + // Type Aliases + // + using KernelSchedule = KernelScheduleType; + + // For cases where we can't have a void type, we can use this to allow the + // code to compile when the scale / zero is void. + using NonVoidElementScale = + cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = + cute::conditional_t, float, ElementZero>; + + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the + // code to compile when the scale is void. + using NonVoidStrideScale = + cute::conditional_t, + cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((cutlass::gemm::detail::is_k_major()), + "The transformed matrix (A) must be K-major."); + + static_assert((sizeof(ElementB) == 2) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element (matrix B) must be 2 bytes OR both " + "inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major " + "if B is scaled]."); + + static_assert(std::is_same_v, + "TiledMma::ValTypeC must be the same as ElementAccumulator."); + + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemCopyAtomScale = Copy_Atom; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any + // rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = + cute::conditional_t>>; + using InternalElementB = + cute::conditional_t>>; + + using TransformA = cute::identity; + using TransformB = cute::identity; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = + cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), + shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, + "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), + Int{}))); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major + // only (e.g. tf32, fp32, fp8, int8). + static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, + layout::ColumnMajor> && + cute::is_same_v, + layout::RowMajor>; + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc " + "for this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + using GmmaSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // These two restrictions are related, so we place the assertions together. + // To relax them, we need to handle loading more than 1 row of scales for + // every main loop iteration. We must also handle updating the pipeline + // transaction bytes on the fly. NOTE: Deleting this assertion without + // required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, + "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible, not formatte for + // easier comparison + // clang-format off + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + // clang-format on + + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(int32_t(0), int32_t(0), int32_t(0))))); + + using ATensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + shape(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(int32_t(0), int32_t(0), int32_t(0)))), + PrepackedStrideA{})); + + using BTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(StrideB{}, int32_t(0)), StrideB{})); + using ScaleTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + using ZeroTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { + return make_tma_copy( + GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}), + shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_scale( + ScaleTensor tensor_scale = ScaleTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_scale, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_zero( + ZeroTensor tensor_zero = ZeroTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_zero, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) { + return make_tma_copy( + GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } + + public: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic + // clang-format off + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + // clang-format on + + // + // section setup end + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to + // define the TMA types + // Device side kernel params + struct Params { + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A()); + using TMA_Scale = decltype(make_tma_copy_scale()); + using TMA_Zero = decltype(make_tma_copy_zero()); + using TMA_B = decltype(make_tma_copy_B()); + + // required by outer loop: i.e. + // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here + // to handle the prepacked layout + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) { + return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride)); + }; + + typename Params::TMA_A tma_load_a; + typename Params::TMA_B tma_load_b; + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); + tma_load_a = make_tma_copy_A( + make_logical_tensor(ptr_A, shape(layout), stride(layout))); + + tma_load_b = make_tma_copy_B( + make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB)); + + int32_t scale_k = + (ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0; + int32_t group_size = (ModeHasScales) ? args.group_size : 0; + + if constexpr (ModeHasScales) { + tma_load_scale = make_tma_copy_scale( + make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS)); + } + + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + tma_load_zero = make_tma_copy_zero( + make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS)); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return {tma_load_a, tma_load_b, tma_load_scale, + tma_load_zero, scale_k, group_size}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `SwapAB ? N : M -> M` since we dont support SwapAB + // clang-format off + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = M; + const int scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + + } + // clang-format off + + // Modified from upstream, should be kept close to that when possible + // the main difference is special handling for the prepacked A layout + // + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the + // contract Returned tuple must contain at least two elements, with the first + // two elements being: gA_mkl - The tma tensor, A after a local tile so it + // has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local + // tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be + // specified as needed by this collective. + // NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the + // values within a prepacked block. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params) const { + using X = Underscore; + auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL), + K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL); + + // (TILE_V,TILE_B,m,k,l) + auto make_gA_mkl = [&]() { + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); + return local_tile(mA_mkl, + make_shape(size<0>(layout), PPBlocksPerTile_MK{}), + make_coord(0, make_coord(_, _))); + }; + + // (TILE_N,TILE_K,n,k,l) + auto make_gB_nkl = [&]() { + Tensor mB_nkl = + mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); + return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), + Step{}); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gS_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gZ_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(), + make_gZ_mkl()); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in load_init."); + } + } + + // Similar to upstream, should be kept close to that when possible + // the main difference is in the layout comments + // clang-format off + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template < + class... Ts, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. + copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + // clang-format on + + // Modified from upstream, should be kept close to that when possible + // the main differences are handling the prepacked A layout, and separating + // the loading of A from upcoverting A + // + // Perform a collective-scoped matrix multiply-accumulate + // Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for " + "RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + // ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset + auto constexpr smem_A = SmemLayoutA{}; + + // convert: + // ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset + // to: + // (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset + // which can be thought of as: + // (T, MMA, (MMA_M, MMA_K), pipe) -> offset + auto constexpr smem_A_mma_ = + make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A), + zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A)); + // flatten to: + // (T, MMA, MMA_M, MMA_K, pipe) -> offset + auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), + smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate fragments and descriptors + Tensor tCrA_load = make_tensor( + tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K) + Tensor tCrA_mma = make_fragment_like(tCrA_load); + + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + static constexpr int A_CPY_VEC = + decltype(max_common_vector(tCsA, tCrA_load)){}; + + static constexpr int CONVERSION_WIDTH = + std::min(A_CPY_VEC, int(size<0>(tCrA_mma))); + + auto load_A_to_registers = [&](int read_stage) { + copy(create_auto_vectorizing_copy(), + tCsA(_, _, _, read_stage), tCrA_load(_, _, _)); + }; + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = + partition_extra_mma_info(thread_mma, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info( + tiled_mma, partitioned_extra_info, warp_group_thread_idx); + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + auto convert_A = [&, a_vec = Int{}](int k_block, + int read_stage) { + load_extra_info_to_registers(partitioned_extra_info, + copy_partitions_extra_info, k_block, + read_stage); + transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info, + k_block); + }; + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + load_A_to_registers(read_stage); + convert_A(0, read_stage); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, smem_pipe_read.index()); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to + // overwrite the A registers for the first mma. + warpgroup_wait(); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, + // so we can release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } else { + convert_A(k_block + 1, read_stage); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, read_stage); + } + } + } + + warpgroup_fence_operand(accum); + } + + // Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it + ++smem_pipe_release; + } + } + + private: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE + auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE + auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Returns the tiled copy and copy views for the extra inputs. + template + CUTLASS_DEVICE + auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Similar to `copy_A_and_extra_info` upstream, should be kept the same when + // possible + // the main differences this only loads the extra info into registers and + // not A (since we now preload more of A in the main pipeline) + // Load scales and zeros into registers if required + template + CUTLASS_DEVICE void load_extra_info_to_registers( + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, + int read_stage) { + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), + tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), + tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } + } + + // Similar to upstream, should be kept the same when possible. + // the main differences are that `convert_tensor` supports interleaved + // layouts and bfloat16 has been optimized. `transform_internal_A` has also + // been inlined for code simplicity. + // Utilities to transform A. + template + CUTLASS_DEVICE void transform_A_kblock( + TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, + int const k_block) { + auto in = tCrA_load(_, _, k_block); + auto out = tCrA_mma(_, _, k_block); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + convert_tensor(in, out, vec_A); + } else if constexpr (ModeHasScales) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto converted_inputs = + make_fragment_like(tCrA_mma)(_, _, k_block); + auto scales = tCrS(_, _, 0); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, vec_A); + // Apply scales and broadcast across inputs, store in converted_inputs + + // We need to cast to nv_bfloat16 for the multiply since + // `cutlass::bfloat16_t` has an overloaded operator* that upconverts to + // float, which nvcc will not optimize to using vectorized fma + // instructions (i.e. hfma.bf16_v2) + if constexpr (std::is_same_v) { + cute::transform( + recast(converted_inputs), recast(scales), + recast(converted_inputs), cute::multiplies{}); + } else { + cute::transform(converted_inputs, scales, converted_inputs, + cute::multiplies{}); + } + + // Apply zeros if required + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto converted_zeros = make_fragment_like(tCrZ)(_, _, 0); + + convert_tensor(tCrZ(_, _, 0), converted_zeros); + if constexpr (std::is_same_v) { + cute::transform(recast(converted_inputs), + recast(converted_zeros), + recast(converted_inputs), cute::plus{}); + } else { + cute::transform(converted_inputs, converted_zeros, converted_inputs, + cute::plus{}); + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } else { + static_assert(cutlass::detail::dependent_false, + "No A data is loaded."); + } + } + + // Modified from upstream, should be kept the same when possible + // the main differences is that this version supports interleaved converts + // Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor( + Tensor const& in, + Tensor& out, + cute::Int width = {}) { + // This is an element-wise conversion where we expect both tensors to have + // the same layout. As a result, we can cast as a cutlass array to use the + // fast numeric converters without worrying about indexing into the layout. + constexpr int N = cosize_v; + + // The inputs must be backed by registers & be statically sized. + static_assert(is_rmem::value, + "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, + "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, + "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), + "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, + "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + + using Converter = cutlass::InterleavedNumericArrayConverter< + IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = + reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = + reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } +}; + +} // namespace machete diff --git a/gptqmodel_ext/machete/machete_mm_kernel.cuh b/gptqmodel_ext/machete/machete_mm_kernel.cuh new file mode 100644 index 000000000..cc50e68b0 --- /dev/null +++ b/gptqmodel_ext/machete/machete_mm_kernel.cuh @@ -0,0 +1,309 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/torch_utils.hpp" +#include "machete_collective_builder.cuh" +#include "machete_prepacked_layout.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +// NOTE This kernel computes D = alpha * A * B + beta * C by computing +// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma +// instructions only support sourcing from registers for the left-hand +// operand, we want to upconvert/decompress the quantized operand in +// register. Since the primary use case we want to support is Y = XW^t where +// W is quantized, in this situation or right-hand operand is quantized so +// we compute the transpose to move it to the left-hand side. +template +struct MacheteKernelTemplate { + static constexpr bool with_C = false; // not ever used + static constexpr bool with_group_scales = !std::is_same_v; + static constexpr bool with_group_zeropoints = + !std::is_same_v; + static constexpr bool with_channel_scales = + !std::is_same_v; + static constexpr bool with_token_scales = !std::is_same_v; + + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementC = cute::conditional_t; + using ElementAccumulator = AccumulatorT; + using ElementCompute = AccumulatorT; // For Epilogue + // Use dummy values when we don't have scales or zeropoints + using ElementZGroup = + cute::conditional_t; + using ElementSGroup = + cute::conditional_t; + using ElementConvertGroup = + cute::conditional_t; + using ElementSChannel = + cute::conditional_t; + using ElementSToken = + cute::conditional_t; + + using BTypeTuple = cute::conditional_t< + with_group_scales, + cute::conditional_t, + cute::tuple>, + ElementB>; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using LayoutScale = cutlass::layout::RowMajor; + // not actually used since B has the prepacked layout, but required by cutlass + using _LayoutB = cutlass::layout::ColumnMajor; + + // Interface strides expected by create_arguments (will get transposed) + using StrideA = cutlass::detail::TagToStrideA_t; + using StrideC = cutlass::detail::TagToStrideA_t; + using StrideD = cutlass::detail::TagToStrideA_t; + using StrideSGroup = cutlass::detail::TagToStrideA_t; + using StrideZGroup = StrideSGroup; + + using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutC_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using PrepackedLayoutB = + PrepackedLayoutBTemplate; + + static int constexpr TileShapeK = + 128 * 8 / cutlass::sizeof_bits::value; + static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentC = + (with_C) ? 128 / cutlass::sizeof_bits_v : 0; + static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v; + + using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{}, + cute::Int{})); + using ClusterShape = typename ScheduleConfig::ClusterShape; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using TileScheduler = typename ScheduleConfig::TileScheduler; + + static_assert( + (!with_channel_scales && !with_token_scales) || + ((with_channel_scales && with_token_scales) && + std::is_same_v), + "Currently token and channel scales (if present) must be the same type"); + + // Currently only supports float scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + static_assert((with_channel_scales || with_token_scales) || + (std::is_same_v && + std::is_same_v), + "Currently token and channel scales (if present) must be float " + "(and if one is present the other must be too)"); + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using EVTCompute = + std::conditional_t; + + // EVTCompute + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::VLLMCollectiveBuilder< + cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass, + BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // stride_B is unused (since B is prepacked), but still required by cutlass + using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>; + + using Arguments = typename Gemm::Arguments; + using MainloopArguments = typename GemmKernel::MainloopArguments; + using EpilogueArguments = typename GemmKernel::EpilogueArguments; + + static Arguments create_arguments( + cudaStream_t stream, + torch::Tensor const& A, // MxK matrix + torch::Tensor const& B, // KxN prepacked matrix + torch::Tensor& D, // MxN matrix + std::optional const& maybe_g_scales, // scale_KxN matrix + std::optional const& maybe_g_zeros, // scale_KxN matrix + std::optional maybe_group_size, + std::optional const& maybe_ch_scales, // len N vector + std::optional const& maybe_tok_scales) // len M vector + { + static_assert(!with_group_zeropoints || with_group_scales); + + int M = A.size(0), N = B.size(1), K = A.size(1); + TORCH_CHECK(D.size(0) == M && D.size(1) == N); + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_S_group = + maybe_make_cute_layout(maybe_g_scales, "group_scales"); + auto layout_Z_group = + maybe_make_cute_layout(maybe_g_zeros, "group_zeros"); + int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0; + int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0; + + auto unwrap = [](auto const& t) { + return t ? t->const_data_ptr() : nullptr; + }; + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto S_group_ptr = + static_cast(unwrap(maybe_g_scales)); + auto Z_group_ptr = static_cast(unwrap(maybe_g_zeros)); + auto S_channel_ptr = + static_cast(unwrap(maybe_ch_scales)); + auto S_token_ptr = + static_cast(unwrap(maybe_tok_scales)); + + int const group_size = + maybe_group_size == -1 ? K : maybe_group_size.value_or(K); + int const scale_k = (K + group_size - 1) / group_size; + + TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); + TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); + + if constexpr (with_group_scales) { + TORCH_CHECK(S_group_ptr && layout_S_group); + TORCH_CHECK((size<0>(*layout_S_group) == scale_k && + size<1>(*layout_S_group) == N)); + } else { + TORCH_CHECK(!S_group_ptr, "Scales not supported"); + } + + if constexpr (with_group_zeropoints) { + TORCH_CHECK(Z_group_ptr && layout_Z_group); + TORCH_CHECK((size<0>(*layout_Z_group) == scale_k && + size<1>(*layout_Z_group) == N)); + TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group, + "Scales and zeros must have the same layout"); + } else { + TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported"); + } + + if constexpr (with_channel_scales || with_token_scales) { + TORCH_CHECK( + (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) && + (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1)); + } + + // Transpose A and D + // A doesn't need to be transposed since cutlass expects a NxK matrix + // for B (which is At) + auto stride_At = layout_A.stride(); + auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); + + MainloopArguments mainloop_arguments{}; + // {Accum, C, C_layout, D, D} + EpilogueArguments epilogue_arguments{}; + + if constexpr (with_channel_scales || with_token_scales) { + epilogue_arguments = + EpilogueArguments{ChTokScalesEpilogue::prepare_args( + *maybe_ch_scales, *maybe_tok_scales), + nullptr, + {}, + D_ptr, + stride_Dt}; + } else { + epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt}; + } + + if constexpr (with_group_scales && with_group_zeropoints) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = MainloopArguments{ + B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size, Z_group_ptr}; + } else if constexpr (with_group_scales) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size}; + } else { + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; + } + + return Arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, 1}, + mainloop_arguments, + epilogue_arguments}; + }; + + static size_t get_workspace_size(Arguments const& args) { + return Gemm::get_workspace_size(args); + } + + static bool can_implement(Arguments const& args) { + return Gemm::can_implement(args) == cutlass::Status::kSuccess; + } + + static void run(Arguments const& args, void* workspace, cudaStream_t stream) { + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(args, workspace, stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Machete kernel failed to initialize workspace"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed"); + } +}; + +}; // namespace machete diff --git a/gptqmodel_ext/machete/machete_mm_launcher.cuh b/gptqmodel_ext/machete/machete_mm_launcher.cuh new file mode 100644 index 000000000..cabe0af46 --- /dev/null +++ b/gptqmodel_ext/machete/machete_mm_launcher.cuh @@ -0,0 +1,75 @@ +#pragma once + +#include +#include + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" + +namespace machete { + +struct MMArgs { + torch::Tensor const& A; + torch::Tensor const& B; + vllm::ScalarType const& b_type; + std::optional const& maybe_out_type; + std::optional const& maybe_group_scales; + std::optional const& maybe_group_zeros; + std::optional maybe_group_size; + std::optional const& maybe_channel_scales; + std::optional const& maybe_token_scales; + std::optional maybe_schedule; +}; + +struct SupportedSchedulesArgs { + at::ScalarType a_type; + vllm::ScalarType b_type; + std::optional maybe_group_scales_type; + std::optional maybe_group_zeros_type; + std::optional maybe_channel_scales_type; + std::optional maybe_token_scales_type; + std::optional maybe_out_type; +}; + +torch::Tensor mm_dispatch(MMArgs args); + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args); + +template +torch::Tensor run_impl(MMArgs args) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); + + auto device = args.A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + int M = args.A.size(0); + int N = args.B.size(1); + int K = args.A.size(1); + + // Allocate output + torch::Tensor D = torch::empty( + {M, N}, + torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + + auto arguments = MacheteKernel::create_arguments( + stream, // + args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros, + args.maybe_group_size, args.maybe_channel_scales, + args.maybe_token_scales); + TORCH_CHECK(MacheteKernel::can_implement(arguments), + "Machete kernel cannot be run with these arguments"); + + size_t workspace_size = MacheteKernel::get_workspace_size(arguments); + torch::Tensor workspace = torch::empty( + workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); + + MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream); + + return D; +}; + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_prepack_kernel.cuh b/gptqmodel_ext/machete/machete_prepack_kernel.cuh new file mode 100644 index 000000000..d002355ca --- /dev/null +++ b/gptqmodel_ext/machete/machete_prepack_kernel.cuh @@ -0,0 +1,76 @@ +#pragma once + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +template +static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) { + auto constexpr block_size = + Int{}; + auto constexpr eles_per_thread = Int{}; + static_assert(block_size % threads == 0, + "block_size must be divisible by the number of threads"); + + // Which pre-packed are we responsible for + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z); + auto tB_in = local_tile( + B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}), + blk_coord); + + // Find the start offset in the output for this pre-packed block + auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in)); + + // Tensor representing a 1:1 mapping to the output space in 1D + auto tB_out_linear = + make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord), + make_layout(make_shape(block_size))); + // Mapping from output space (1D) to input space + auto tB_in_linear = make_tensor( + tB_in.data(), + tB_in.layout() + .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset())) + .with_shape(make_shape(block_size))); + + // Tile for this specific thread (could have used a TiledCopy but these work + // best with 2d layouts, this is a simple 1d layout so local_tile is enough, + // we are also not that concerned with performance for this kernel) + auto thr_tB_in_linear = + local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x); + auto thr_tB_out_linear = + local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition + auto fragment = make_tensor(shape(thr_tB_in_linear)); + + copy(thr_tB_in_linear, fragment); + copy(Copy_Atom{}, fragment, thr_tB_out_linear); +} + +template +static void prepack_B_template( + cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr, + InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) { + using TileShapeNKL = + decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{})); + auto ilvd_NKbNbKL_to_offset = + PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout)); + + TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); + + auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); + auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout); + + auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); + + prepack_B_kernel<128, PrepackedLayoutB> + <<>>(B_in, B_out_ptr); +} + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_prepack_launcher.cuh b/gptqmodel_ext/machete/machete_prepack_launcher.cuh new file mode 100644 index 000000000..634b651a4 --- /dev/null +++ b/gptqmodel_ext/machete/machete_prepack_launcher.cuh @@ -0,0 +1,74 @@ +#pragma once + +#include "machete_prepack_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" + +namespace machete { + +struct PrepackBArgs { + torch::Tensor const& B; + at::ScalarType a_type; + vllm::ScalarType b_type; + std::optional maybe_group_scales_type; +}; + +template +torch::Tensor prepack_impl(torch::Tensor const B) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); + using ElementB = typename PrepackedLayoutB::ElementB; + using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK; + + auto device = B.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + auto B_ptr = static_cast(B.const_data_ptr()); + // elements per storage item for B + auto eles_per_storage = + (B.dtype().itemsize() * 8) / cute::sizeof_bits_v; + + // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to + // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L) + auto Bt_packed = B.t(); + + TORCH_CHECK( + (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0, + "B.shape[0] (in terms of unpacked elements) must be a multiple of ", + size<1>(PPBlockShape_NK{})); + TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0, + "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{})); + + using StrideB = cutlass::detail::TagToStrideB_t; + auto const l_Bt_packed = make_cute_layout(Bt_packed, "B"); + + // convert (N,packed_K,L) layout to (N,K,L) layout + // in effect we want to do: blocked_product(layout_Bt_packed, + // make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}), + // Step<_1, _0, _2>{})); + // but blocked_product does not support dynamic strides so we implement the + // equivalent manually, + // new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L) + // new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage) + // when s1 == 1 + TORCH_CHECK(stride<1>(l_Bt_packed) == 1); + // clang-format off + auto const layout_Bt = make_layout( + transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) { + return idx == 1 ? ele * eles_per_storage : ele; + }), + transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) { + return idx != 1 ? ele * eles_per_storage : ele; + })); + // clang-format on + + // Allocate output + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); + + prepack_B_template( + stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); + + return D; +}; + +torch::Tensor prepack_B_dispatch(PrepackBArgs args); + +}; // namespace machete \ No newline at end of file diff --git a/gptqmodel_ext/machete/machete_prepacked_layout.cuh b/gptqmodel_ext/machete/machete_prepacked_layout.cuh new file mode 100644 index 000000000..4a7d6341e --- /dev/null +++ b/gptqmodel_ext/machete/machete_prepacked_layout.cuh @@ -0,0 +1,253 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "machete_collective_builder.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +struct IlvBlkLayoutAuto {}; + +// This defines a prepacked layout for the B matrix, where the matrix is broken +// up into PPBlockShape_NK blocks. The data within each block is then compactly +// stored in memory such that when performing a TiledMMA operation with the same +// shape as prepacked block, all the data for a given thread is contiguous in +// memory. This allows us to use wider shared memory loads when loading B from +// shared memory. The values within a thread are also potentially interlaeved +// inorder to allow for more efficient upconverting. +// +// The contract here is that the `TiledMma` determined below matches the one +// ultimately used in the kernel. (this is also why the other element types are +// required along with the kernel schedule) +template +// clang-format on +struct PrepackedLayoutBTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementAccumulator = AccumulatorT; + using ElementMma = MmaType; + + // Interleave for 4bit bit types when we are not upconverting to fp8 or int8, + // in those cases case we use a LUT using prmt instructions to upconvert and + // is more efficient if the data is not interleaved For 8bit+ prmt + // instructions makes non-interleaved layouts efficient enough we don't need + // iterleaved layouts (and can reuse more of the existing cutlass converts) + static constexpr bool should_interleave = + sizeof_bits_v <= 4 && + !std::is_same_v && + !std::is_same_v; + + // Only use interleaved layouts for subbyte weights, + using IlvdBlkLayout = std::conditional_t< + std::is_same_v, + std::conditional_t< + should_interleave, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, + IlvBlkLayout_>; + + // TODO (LucasWilkinson): compare the performance for other sizes + // Prepacked block shape, smallest layout atom for loading into registers + // (can contain multiple wgmma instructions worth of data in one block) + // We ideally want this to be configured such that a thread can perform 128bit + // loads, i.e. we amount of data associated with each thread within a + // prepacked block is a multiple of 128bits, when using a cooperative sechdule + // we have 256 threads working a single block at a time, this means each + // thread works on `sizeof_bits_v * (128*64) / 256` bits of data, + // for a 4bit type this would be 128bits + using PPBlockShape_NK = Shape<_128, _64>; + + // Create the shape of the tile anticipated to be used by the GEMM kernel, + // when the kernel executes we will compute `Ct = Bt * At` since the + // quantized weights (B), must be the lhs operand so the flow through + // registers. + // The _128 here doesn't actually impact the shape of the stored tile directly + // but may impact the op selected by rs_op_selector + using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{}, + size<1>(PPBlockShape_NK{}))); + + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + // Prepacked block, (athrid, val) -> (N,K) + // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() { + return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{})); + } + + // Prepacked block, (N,K) -> (athrid, val) + // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() { + return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() { + // Return iterleaved layout + return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + } + + // Prepacked block, (athrid, val) -> (storage_offset) + // i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() { + auto layout_no_interleave = + make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + + if constexpr (std::is_same_v) { + return layout_no_interleave; + } else { + // interleave by transforming FrgV into interleaved blocks where each + // block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is + // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4) + // if FrgV is {A, B, C, D, E, F, G, H} + // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} + auto frgV = get<1, 0>(layout_no_interleave); + auto ilvdBlk = IlvdBlkLayout{}; + static_assert(size(frgV) % size(ilvdBlk) == 0, + "FrgV must be divisible by size(ilvdBlk)"); + auto ilvd_FrgV = make_layout( + make_shape(shape(ilvdBlk), Int{}), + make_stride(stride(ilvdBlk), size(ilvdBlk))); + + // Return iterleaved layout + return make_layout( + get<0>(layout_no_interleave), + make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave))); + } + } + + // Prepacked block, (M,K) -> (storage_offset) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() { + // do (M,K) -> (athrid, val) -> (storage_idx) + return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV()); + } + + // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_TV_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) + // => ((athrid, val), (BlocksN, BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy( + Shape_NKL shape_mkl) { + auto layout = TVbNbKL_to_offset(shape_mkl); + // for 4-bit elements, having >= 64 values per column + // allows TMA to load full 32-byte sectors + auto inner_layout = + make_layout(make_shape(_256{}, size<0>(layout) / _256{})); + + return make_layout(inner_layout, get<1>(layout), get<2>(layout)); + } + + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_ilvd_NK_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN, + // BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // (BlocksN, BlocksK, L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) { + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + auto stride = size(PPBlockShape_NK{}); + + // (BlocksN, BlocksK, L) -> (storage_idx) + return make_layout(blocks_shape, compact_col_major(blocks_shape, stride)); + } + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { + auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})), + make_layout(size<1>(PPBlockShape_NK{}))); + + // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L) + auto tiled_A = zipped_divide(make_layout(shape_mkl), tile); + return tiled_A.compose(ppblock_TV_to_NK(), _); + } + + // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L) + template + CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) { + auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl); + return blocked_product(ppblock_NK_to_TV(), + make_layout(shape<1>(TVbNbK_to_NKL_layout))); + } +}; + +}; // namespace machete diff --git a/gptqmodel_ext/machete/machete_pytorch.cu b/gptqmodel_ext/machete/machete_pytorch.cu new file mode 100644 index 000000000..05a51ee21 --- /dev/null +++ b/gptqmodel_ext/machete/machete_pytorch.cu @@ -0,0 +1,73 @@ +#include "machete_mm_launcher.cuh" +#include "machete_prepack_launcher.cuh" +#include "core/scalar_type.hpp" + +#include "core/registration.h" + +namespace machete { + +using namespace vllm; + +std::vector supported_schedules( + at::ScalarType a_type, int64_t b_type_id, + std::optional maybe_group_scales_type, + std::optional maybe_group_zeros_type, + std::optional maybe_channel_scales_type, + std::optional maybe_token_scales_type, + std::optional maybe_out_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return supported_schedules_dispatch({ + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type, + .maybe_group_zeros_type = maybe_group_zeros_type, + .maybe_channel_scales_type = maybe_channel_scales_type, + .maybe_token_scales_type = maybe_token_scales_type, + .maybe_out_type = maybe_out_type, + }); +} + +torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, + int64_t b_type_id, + std::optional const& maybe_out_type, + std::optional const& maybe_group_scales, + std::optional const& maybe_group_zeros, + std::optional maybe_group_size, + std::optional const& maybe_channel_scales, + std::optional const& maybe_token_scales, + std::optional maybe_schedule) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return mm_dispatch({.A = A, + .B = B, + .b_type = b_type, + .maybe_out_type = maybe_out_type, + .maybe_group_scales = maybe_group_scales, + .maybe_group_zeros = maybe_group_zeros, + .maybe_group_size = maybe_group_size, + .maybe_channel_scales = maybe_channel_scales, + .maybe_token_scales = maybe_token_scales, + .maybe_schedule = maybe_schedule}); +} + +torch::Tensor prepack_B( + torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, + std::optional const& maybe_group_scales_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return prepack_B_dispatch( + {.B = B, + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type}); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("machete_prepack_B", &prepack_B); + m.impl("machete_mm", &mm); +} + +// use CatchAll since supported_schedules has no tensor arguments +TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) { + m.impl("machete_supported_schedules", &supported_schedules); +} + +}; // namespace machete diff --git a/setup.py b/setup.py index 3be472315..1234b3bd2 100644 --- a/setup.py +++ b/setup.py @@ -6,12 +6,53 @@ import re import subprocess import sys +import tarfile +import urllib.request from pathlib import Path +from shutil import rmtree from setuptools import find_namespace_packages, find_packages, setup from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel +CUTLASS_VERSION = "3.5.0" +CUTLASS_RELEASE_URL = f"https://github.com/NVIDIA/cutlass/archive/refs/tags/v{CUTLASS_VERSION}.tar.gz" + + +def _ensure_cutlass_source() -> Path: + deps_dir = Path("build") / "_deps" + deps_dir.mkdir(parents=True, exist_ok=True) + + cutlass_root = deps_dir / f"cutlass-v{CUTLASS_VERSION}" + marker = cutlass_root / ".gptqmodel_complete" + if marker.exists(): + return cutlass_root.resolve() + + archive_path = deps_dir / f"cutlass-v{CUTLASS_VERSION}.tar.gz" + if not archive_path.exists(): + print(f"Downloading CUTLASS v{CUTLASS_VERSION} ...") + with urllib.request.urlopen(CUTLASS_RELEASE_URL) as response: + data = response.read() + archive_path.write_bytes(data) + + if cutlass_root.exists(): + rmtree(cutlass_root) + + with tarfile.open(archive_path, "r:gz") as tar: + extract_kwargs = {"path": deps_dir} + if sys.version_info >= (3, 12): + extract_kwargs["filter"] = "data" + tar.extractall(**extract_kwargs) + + extracted_dir = deps_dir / f"cutlass-{CUTLASS_VERSION}" + if not extracted_dir.exists(): + raise RuntimeError("Failed to extract CUTLASS archive") + + extracted_dir.rename(cutlass_root) + marker.touch() + return cutlass_root.resolve() + + # --------------------------- # Helpers (no torch required) # --------------------------- @@ -203,16 +244,40 @@ def _version_geq(version: str | None, major: int, minor: int = 0) -> bool: def _nvcc_release_version() -> str | None: - out = _probe_cmd(["nvcc", "--version"]) - if not out: - print( - "NVCC not found: For Ubuntu, run `sudo update-alternatives --config cuda` to fix path for already installed Cuda." - ) - return None - - match = re.search(r"release\s+(\d+)\.(\d+)", out) - if match: - return f"{match.group(1)}.{match.group(2)}" + # Search for nvcc in common locations before giving up. + candidates: list[str] = [] + nvcc_env = _read_env("NVCC") + if nvcc_env: + candidates.append(nvcc_env) + + cuda_home = _read_env("CUDA_HOME") + cuda_path = _read_env("CUDA_PATH") + + candidates.extend( + [ + "nvcc", + str(Path(cuda_home).joinpath("bin", "nvcc")) if cuda_home else None, + str(Path(cuda_path).joinpath("bin", "nvcc")) if cuda_path else None, + "/usr/local/cuda/bin/nvcc", + ] + ) + + seen = set() + for cmd in candidates: + if not cmd or cmd in seen: + continue + seen.add(cmd) + out = _probe_cmd([cmd, "--version"]) + if not out: + continue + match = re.search(r"release\s+(\d+)\.(\d+)", out) + if match: + return f"{match.group(1)}.{match.group(2)}" + + print( + "NVCC not found (checked PATH, $CUDA_HOME/bin, $CUDA_PATH/bin, /usr/local/cuda/bin). " + "For Ubuntu, run `sudo update-alternatives --config cuda` to fix path for already installed Cuda." + ) return None @@ -355,18 +420,28 @@ def _resolve_wheel_url(tag_name: str, wheel_name: str) -> str: # Fallback: default GitHub template return DEFAULT_WHEEL_URL_TEMPLATE.format(tag_name=tag_name, wheel_name=wheel_name) -# Decide HAS_CUDA_V8 without torch +# Decide HAS_CUDA_V8 / HAS_CUDA_V9 without torch HAS_CUDA_V8 = False +HAS_CUDA_V9 = False if CUDA_ARCH_LIST: - HAS_CUDA_V8 = not ROCM_VERSION and _has_cuda_v8_from_arch_list(_parse_arch_list(CUDA_ARCH_LIST)) + arch_list = _parse_arch_list(CUDA_ARCH_LIST) + try: + caps = [float(tok.split("+", 1)[0]) for tok in arch_list] + except Exception: + caps = [] + if not ROCM_VERSION: + HAS_CUDA_V8 = any(cap >= 8.0 for cap in caps) + HAS_CUDA_V9 = any(cap >= 9.0 for cap in caps) else: smi = _probe_cmd(["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"]) if smi: try: caps = [float(x.strip()) for x in smi.splitlines() if x.strip()] HAS_CUDA_V8 = any(cap >= 8.0 for cap in caps) + HAS_CUDA_V9 = any(cap >= 9.0 for cap in caps) except Exception: HAS_CUDA_V8 = False + HAS_CUDA_V9 = False if RELEASE_MODE == "1": gptqmodel_version = f"{gptqmodel_version}+{get_version_tag()}" @@ -397,6 +472,7 @@ def _env_enabled_any(names, default="1") -> bool: BUILD_MARLIN = _env_enabled_any(os.environ.get("GPTQMODEL_BUILD_MARLIN", "1")) +BUILD_MACHETE = _env_enabled(os.environ.get("GPTQMODEL_BUILD_MACHETE", "1")) BUILD_EXLLAMA_V2 = _env_enabled(os.environ.get("GPTQMODEL_BUILD_EXLLAMA_V2", "1")) BUILD_QQQ = _env_enabled(os.environ.get("GPTQMODEL_BUILD_QQQ", "1")) BUILD_AWQ = _env_enabled(os.environ.get("GPTQMODEL_BUILD_AWQ", "1")) @@ -442,6 +518,17 @@ def _env_enabled_any(names, default="1") -> bool: ], } + cutlass_root = _ensure_cutlass_source() + cutlass_include_paths = [ + Path("gptqmodel_ext/cutlass_extensions").resolve(), + cutlass_root / "include", + cutlass_root / "examples/common/include", + cutlass_root / "tools/library/include", + ] + cutlass_include_flags = [f"-I{path}" for path in cutlass_include_paths] + extra_compile_args["cxx"] += cutlass_include_flags + extra_compile_args["nvcc"] += cutlass_include_flags + # Windows/OpenMP note: adjust flags as needed for MSVC if you add native Windows wheels if sys.platform == "win32": extra_compile_args["cxx"] = ["/O2", "/std:c++17", "/openmp", "/DNDEBUG", "/DENABLE_BF16"] @@ -453,10 +540,7 @@ def _env_enabled_any(names, default="1") -> bool: if not ROCM_VERSION: # if _version_geq(NVCC_VERSION, 13, 0): # extra_compile_args["nvcc"].append("--device-entity-has-hidden-visibility=false") - extra_compile_args["nvcc"] += [ - # Allow instantiations of __global__ templates to live in different - # translation units (we split marlin kernels for Ninja parallelism). - "-static-global-template-stub=false", + nvcc_extra_flags = [ "--threads", "8", # NVCC parallelism "--optimize=3", # alias for -O3 # "-rdc=true", # enable relocatable device code, required for future cuda > 13.x <-- TODO FIX ME broken loading @@ -472,6 +556,10 @@ def _env_enabled_any(names, default="1") -> bool: # "--expt-extended-lambda", # allow device lambdas <-- not used "-diag-suppress=179,39,177", # silence some template warnings ] + if _version_geq(NVCC_VERSION, 12, 8): + # Allow instantiations of __global__ templates to live in different TUs; only supported in newer NVCC. + nvcc_extra_flags.insert(0, "-static-global-template-stub=false") + extra_compile_args["nvcc"] += nvcc_extra_flags else: # hipify CUDA-like flags def _hipify_compile_flags(flags): @@ -526,6 +614,60 @@ def _hipify_compile_flags(flags): ) ] + if BUILD_MACHETE and HAS_CUDA_V9 and _version_geq(NVCC_VERSION, 12, 0): + machete_dir = Path("gptqmodel_ext/machete") + machete_generated_dir = machete_dir / "generated" + + machete_sources = [str(machete_dir / "machete_pytorch.cu")] + machete_generated_sources = sorted(machete_generated_dir.glob("*.cu")) + + if not machete_generated_sources: + raise RuntimeError( + "No generated machete kernel templates detected. Run gptqmodel_ext/machete/generate.py" + " with CUTLASS checkout before building." + ) + + machete_sources += [str(path) for path in machete_generated_sources] + + machete_include_dirs = [str(Path("gptqmodel_ext").resolve())] + [str(path) for path in cutlass_include_paths] + + extensions += [ + cpp_ext.CUDAExtension( + "gptqmodel_machete_kernels", + machete_sources, + extra_link_args=extra_link_args, + extra_compile_args=extra_compile_args, + include_dirs=machete_include_dirs, + ) + ] + + if BUILD_MACHETE and HAS_CUDA_V9 and _version_geq(NVCC_VERSION, 12, 0): + machete_dir = Path("gptqmodel_ext/machete") + machete_generated_dir = machete_dir / "generated" + + machete_sources = [str(machete_dir / "machete_pytorch.cu")] + machete_generated_sources = sorted(machete_generated_dir.glob("*.cu")) + + if not machete_generated_sources: + raise RuntimeError( + "No generated machete kernel templates detected. Run gptqmodel_ext/machete/generate.py" + " with CUTLASS checkout before building." + ) + + machete_sources += [str(path) for path in machete_generated_sources] + + machete_include_dirs = [str(Path("gptqmodel_ext").resolve())] + [str(path) for path in cutlass_include_paths] + + extensions += [ + cpp_ext.CUDAExtension( + "gptqmodel_machete_kernels", + machete_sources, + extra_link_args=extra_link_args, + extra_compile_args=extra_compile_args, + include_dirs=machete_include_dirs, + ) + ] + if BUILD_QQQ: extensions += [ cpp_ext.CUDAExtension( @@ -583,38 +725,59 @@ def _hipify_compile_flags(flags): ] if BUILD_AWQ: - extensions += [ - # contain un-hipifiable inline PTX - cpp_ext.CUDAExtension( - "gptqmodel_awq_kernels", - [ - "gptqmodel_ext/awq/pybind_awq.cpp", - "gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu", - "gptqmodel_ext/awq/quantization/gemv_cuda.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ), - # TODO only compatible with ampere? - # arch_flags = get_compute_capabilities({80, 86, 89, 90}) - # extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags) - cpp_ext.CUDAExtension( - "gptqmodel_awq_v2_kernels", - [ - "gptqmodel_ext/awq/pybind_awq_v2.cpp", - "gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu", - "gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ) - ] + if ROCM_VERSION: + print("Skipping AWQ kernels on ROCm: inline PTX is CUDA-only.") + else: + extensions += [ + # contain un-hipifiable inline PTX + cpp_ext.CUDAExtension( + "gptqmodel_awq_kernels", + [ + "gptqmodel_ext/awq/pybind_awq.cpp", + "gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu", + "gptqmodel_ext/awq/quantization/gemv_cuda.cu", + ], + extra_link_args=extra_link_args, + extra_compile_args=extra_compile_args, + ), + # TODO only compatible with ampere? + # arch_flags = get_compute_capabilities({80, 86, 89, 90}) + # extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags) + cpp_ext.CUDAExtension( + "gptqmodel_awq_v2_kernels", + [ + "gptqmodel_ext/awq/pybind_awq_v2.cpp", + "gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu", + "gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.cu", + ], + extra_link_args=extra_link_args, + extra_compile_args=extra_compile_args, + ), + ] + + # Ensure machete kernels are compiled before other extensions + machete_exts = [ext for ext in extensions if getattr(ext, "name", "") == "gptqmodel_machete_kernels"] + if machete_exts: + other_exts = [ext for ext in extensions if getattr(ext, "name", "") != "gptqmodel_machete_kernels"] + extensions[:] = machete_exts + other_exts + + # additional_setup_kwargs = { + # "ext_modules": extensions, + # "cmdclass": {"build_ext": cpp_ext.BuildExtension}, + # } additional_setup_kwargs = { "ext_modules": extensions, - "cmdclass": {"build_ext": cpp_ext.BuildExtension}, + "cmdclass": {"build_ext": cpp_ext.BuildExtension.with_options( + use_ninja=True, + no_python_abi_suffix=True, + build_temp="build/temp", + build_lib="build/lib", + clean_first=False # keep intermediates for reuse + )}, } + # --------------------------- # Cached wheel fetcher # --------------------------- @@ -658,6 +821,7 @@ def run(self): # --------------------------- print(f"CUDA {CUDA_ARCH_LIST}") print(f"HAS_CUDA_V8 {HAS_CUDA_V8}") +print(f"HAS_CUDA_V9 {HAS_CUDA_V9}") print(f"SETUP_KWARGS {additional_setup_kwargs}") print(f"gptqmodel_version={gptqmodel_version}") diff --git a/tests/test_awq.py b/tests/test_awq.py index 6a02ccfef..1afa83198 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -10,6 +10,7 @@ import tempfile import unittest +import torch from datasets import load_dataset from parameterized import parameterized from transformers import AutoTokenizer @@ -17,8 +18,10 @@ from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear from gptqmodel.nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear from gptqmodel.nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear +from gptqmodel.nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear from gptqmodel.nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME +from gptqmodel.utils.machete import _validate_machete_device_support, machete_import_exception from gptqmodel.utils.torch import torch_empty_cache @@ -35,15 +38,82 @@ class TestGroupSize(unittest.TestCase): @classmethod - def setUpClass(self): - self.pretrained_model_id = "/monster/data/model/Llama-3.2-1B" + def setUpClass(cls): + cls.pretrained_model_id = "/monster/data/model/Llama-3.2-1B" # "/monster/data/model/Qwen2.5-0.5B-Instruct/" "/monster/data/model/Qwen2.5-0.5B-Instruct/" # - self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_id, use_fast=True) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.pretrained_model_id, use_fast=True) + + requested_samples = os.getenv("GPTQMODEL_AWQ_CALIB_SAMPLES") + if requested_samples is not None: + sample_count = max(1, int(requested_samples)) + else: + total_mem_gb = 0 + if torch.cuda.is_available(): + try: + total_mem_gb = ( + torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory + / (1024 ** 3) + ) + except Exception: + total_mem_gb = 0 + + if total_mem_gb >= 80: + sample_count = 1024 + elif total_mem_gb >= 48: + sample_count = 512 + else: + sample_count = 192 traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train") - self.calibration_dataset = traindata.select(range(1024)) + cls.calibration_dataset = traindata.select(range(sample_count)) + + cls.quantized_tempdirs = {} + cls.quantized_model_paths = {} + cls.quantize_config_dicts = {} + + quantize_targets = { + (FORMAT.GEMM, 128), + (FORMAT.GEMV, 128), + (FORMAT.GEMV_FAST, 128), + } + + for checkpoint_format, group_size in quantize_targets: + quantize_config = QuantizeConfig( + bits=4, + group_size=group_size, + quant_method=METHOD.AWQ, + format=checkpoint_format, + ) + + model = GPTQModel.load( + cls.pretrained_model_id, + quantize_config=quantize_config, + ) + model.quantize(cls.calibration_dataset, batch_size=1, calibration_concat_size=0) + + tmp_dir = tempfile.TemporaryDirectory() + tmp_dir_name = tmp_dir.name + model.save(tmp_dir_name) + + with open(tmp_dir_name + "/" + QUANT_CONFIG_FILENAME, "r") as f: + file_dict = json.loads(f.read()) + assert model.quantize_config.to_dict() == file_dict + logging.info(f"Saved config file: {file_dict}") + + cls.quantized_tempdirs[(checkpoint_format, group_size)] = tmp_dir + cls.quantized_model_paths[(checkpoint_format, group_size)] = tmp_dir_name + cls.quantize_config_dicts[(checkpoint_format, group_size)] = file_dict + + del model + # torch_empty_cache() + + @classmethod + def tearDownClass(cls): + for tmp_dir in getattr(cls, "quantized_tempdirs", {}).values(): + tmp_dir.cleanup() + # torch_empty_cache() # def test_load_group_128(self): # model = GPTQModel.load( @@ -57,54 +127,47 @@ def setUpClass(self): @parameterized.expand([ (FORMAT.GEMM, BACKEND.GEMM, 128), + (FORMAT.GEMM, BACKEND.MACHETE, 128), (FORMAT.GEMM, BACKEND.MARLIN, 128), (FORMAT.GEMV, BACKEND.GEMV, 128), (FORMAT.GEMV_FAST, BACKEND.GEMV_FAST, 128), ]) def test_quant_and_inference(self, checkpoint_format, backend, group_size: int): - quantize_config = QuantizeConfig( - bits=4, - group_size=group_size, - quant_method=METHOD.AWQ, - format=checkpoint_format, - ) + if backend == BACKEND.MACHETE: + if machete_import_exception is not None: + self.skipTest(f"machete unavailable: {machete_import_exception}") + if not _validate_machete_device_support(): + self.skipTest("Machete requires NVIDIA Hopper or newer (SM90+)") + + key = (checkpoint_format, group_size) + model_path = self.quantized_model_paths[key] + expected_config = self.quantize_config_dicts[key] model = GPTQModel.load( - self.pretrained_model_id, - quantize_config=quantize_config, + model_path, + backend=backend, ) - model.quantize(self.calibration_dataset, batch_size=1, calibration_concat_size=0) - with tempfile.TemporaryDirectory() as tmp_dir_name: - model.save(tmp_dir_name) - - with open(tmp_dir_name + "/" + QUANT_CONFIG_FILENAME, "r") as f: - file_dict = json.loads(f.read()) - # make sure the json dict saved to file matches config in memory - assert model.quantize_config.to_dict() == file_dict - logging.info(f"Saved config file: {file_dict}") + self.assertEqual(model.quantize_config.to_dict(), expected_config) - del model - torch_empty_cache() - - model = GPTQModel.load( - tmp_dir_name, - backend=backend, - ) + self.assert_awq_linear(model, backend) - self.assert_awq_linear(model, backend) + tokens = model.generate("Capital of France is", max_new_tokens=100)[0] + result = model.tokenizer.decode(tokens) + print(f"BACKEND: {backend}, Result: {result}") + if "paris" not in result.lower() and "city" not in result.lower(): + raise AssertionError(" `paris` not found in `result`") - tokens = model.generate("Capital of France is", max_new_tokens=100)[0] - result = model.tokenizer.decode(tokens) - print(f"BACKEND: {BACKEND.GEMM}, Result: {result}") - if "paris" not in result.lower() and "city" not in result.lower(): - raise AssertionError(" `paris` not found in `result`") + del model + # torch_empty_cache() def assert_awq_linear(self, model, backend): has_qqq = False for _, module in model.named_modules(): if backend == BACKEND.GEMM: linear = AwqGEMMQuantLinear + elif backend == BACKEND.MACHETE: + linear = AwqMacheteQuantLinear elif backend == BACKEND.MARLIN: linear = AwqMarlinQuantLinear elif backend == BACKEND.GEMV: diff --git a/tests/test_kernel_output.py b/tests/test_kernel_output.py index e33bef026..0469f2763 100644 --- a/tests/test_kernel_output.py +++ b/tests/test_kernel_output.py @@ -5,6 +5,7 @@ import os import unittest +from typing import List, Tuple import torch from logbar import LogBar @@ -15,9 +16,14 @@ from gptqmodel.adapter.adapter import Adapter, AdapterCache, Lora from gptqmodel.nn_modules.qlinear.bitblas import BitblasQuantLinear from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear +from gptqmodel.nn_modules.qlinear.machete import MacheteQuantLinear from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear +from gptqmodel.utils.machete import ( + _validate_machete_device_support, + machete_import_exception, +) from gptqmodel.utils.model import find_modules @@ -42,29 +48,17 @@ class TestKernelOutput(unittest.TestCase): # model_path = "sliuau/llama3.2-1b-4bit-group128" # hf "sliuau/llama3.2-1b-4bit-group128" model_path = "sliuau/Llama-3.2-3B_4bits_128group_size" target_qliner_map = { - # BACKEND.EXLLAMA_V1: ExllamaQuantLinear, - # BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear, + BACKEND.TORCH: TorchQuantLinear, + BACKEND.MACHETE: MacheteQuantLinear, BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, BACKEND.TRITON: TritonV2QuantLinear, - BACKEND.TORCH: TorchQuantLinear, - # BACKEND.TORCH_FUSED: TorchFusedQuantLinear, BACKEND.BITBLAS: BitblasQuantLinear, - # BACKEND.IPEX: IPEXQuantLinear, BACKEND.MARLIN: MarlinQuantLinear, - # BACKEND.MARLIN_FP16: MarlinQuantLinear, } target = 'model.layers.6.self_attn.v_proj' - m = [ # tuple is dim_0 size and num_sampes for each dim_0 - (1, 256), - (16, 128), - (32, 64), - (64, 32), - (128, 16), - ] - - # sum all the second tuple value for total sample size - random_input_sample_size = sum(t[1] for t in m) + m: List[Tuple[int, int]] = [] + random_input_sample_size = 0 @classmethod @@ -76,6 +70,45 @@ def setUpClass(cls): test_dtypes = [torch.float16, torch.bfloat16] cls.data = {} # key is dtype, v is Data() + def _parse_shapes(expr: str) -> List[Tuple[int, int]]: + shapes: List[Tuple[int, int]] = [] + for part in expr.split(","): + part = part.strip() + if not part: + continue + dim_str, samples_str = part.split(":", 1) + shapes.append((int(dim_str), int(samples_str))) + return shapes + + large_shapes = [(1, 256), (16, 128), (32, 64), (64, 32), (128, 16)] + medium_shapes = [(1, 128), (16, 64), (32, 32), (64, 16)] + small_shapes = [(1, 64), (8, 32), (16, 16)] + + env_shapes = os.getenv("GPTQMODEL_KERNEL_TEST_SHAPES") + if env_shapes: + cls.m = _parse_shapes(env_shapes) + else: + total_mem_gb = 0.0 + if torch.cuda.is_available(): + device_index = DEVICE.index if DEVICE.index is not None else 0 + try: + if torch.cuda.device_count() > device_index: + props = torch.cuda.get_device_properties(device_index) + total_mem_gb = props.total_memory / (1024 ** 3) + except Exception: + total_mem_gb = 0.0 + + if os.getenv("GPTQMODEL_FAST_TESTS", "0") == "1": + cls.m = small_shapes + elif total_mem_gb >= 80: + cls.m = large_shapes + elif total_mem_gb >= 48: + cls.m = medium_shapes + else: + cls.m = small_shapes + + cls.random_input_sample_size = sum(t[1] for t in cls.m) + for dtype in test_dtypes: data = Data() @@ -187,20 +220,28 @@ def _summarize_results( f"{len(failures)} mismatched outputs for backend `{backend}` and dtype `{dtype}`" ) - @parameterized.expand([ + def _maybe_skip_backend(self, backend: BACKEND): + if backend == BACKEND.BITBLAS and os.getenv("RUN_BITBLAS_TESTS", "0") != "1": + self.skipTest("BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)") + + if backend == BACKEND.MACHETE: + if machete_import_exception is not None: + self.skipTest(f"Machete kernel unavailable: {machete_import_exception}") + if not _validate_machete_device_support(): + self.skipTest("Machete requires NVIDIA Hopper or newer (SM90+)") + + float16_cases = [ (BACKEND.TORCH, torch.float16, 0.0000), - # (BACKEND.TORCH_FUSED, torch.float16, 0.0001), (BACKEND.TRITON, torch.float16, 0.00001), - # (BACKEND.EXLLAMA_V1, torch.float16, 0.0050), (BACKEND.EXLLAMA_V2, torch.float16, 0.0068), + (BACKEND.MACHETE, torch.float16, 0.00040), (BACKEND.MARLIN, torch.float16, 0.00035), (BACKEND.BITBLAS, torch.float16, 0.0035), - # (BACKEND.MARLIN_FP16, torch.float16, 0.0035), - # (BACKEND.EXLLAMA_EORA, torch.float16, 0.0025), - ]) + ] + + @parameterized.expand(float16_cases) def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): - if backend == BACKEND.BITBLAS and os.getenv("RUN_BITBLAS_TESTS", "0") != "1": - self.skipTest("BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)") + self._maybe_skip_backend(backend) data = self.data[dtype] out = self.forward(backend=backend, dtype=dtype) @@ -215,20 +256,18 @@ def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance reference_label="Torch output", ) - @parameterized.expand([ + bfloat16_cases = [ (BACKEND.TORCH, torch.bfloat16, 0.0000), - # (BACKEND.TORCH_FUSED, torch.bfloat16, 0.0001), (BACKEND.TRITON, torch.bfloat16, 0.00001), - # (BACKEND.EXLLAMA_V1, torch.bfloat16, 0.0064), (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0054), + (BACKEND.MACHETE, torch.bfloat16, 0.0033), (BACKEND.MARLIN, torch.bfloat16, 0.0031), (BACKEND.BITBLAS, torch.bfloat16, 0.0031), - # (BACKEND.MARLIN_FP16, torch.bfloat16, 0.012), - # (BACKEND.EXLLAMA_EORA, torch.bfloat16, 0.0031), TODO FIX, abnormal output when Exllama Eora kernel is using bfloat16 - ]) + ] + + @parameterized.expand(bfloat16_cases) def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): - if backend == BACKEND.BITBLAS and os.getenv("RUN_BITBLAS_TESTS", "0") != "1": - self.skipTest("BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)") + self._maybe_skip_backend(backend) data = self.data[dtype] out = self.forward(backend=backend, dtype=dtype) @@ -243,20 +282,18 @@ def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance reference_label="Torch output", ) - @parameterized.expand([ + float16_lora_cases = [ (BACKEND.TORCH, torch.float16, 0.0000), - # (BACKEND.TORCH_FUSED, torch.float16, 0.0001), (BACKEND.TRITON, torch.float16, 0.00001), - # (BACKEND.EXLLAMA_V1, torch.float16, 0.0054), (BACKEND.EXLLAMA_V2, torch.float16, 0.0065), + (BACKEND.MACHETE, torch.float16, 0.00040), (BACKEND.MARLIN, torch.float16, 0.00035), (BACKEND.BITBLAS, torch.float16, 0.00035), - # (BACKEND.MARLIN_FP16, torch.float16, 0.0035), - # (BACKEND.EXLLAMA_EORA, torch.float16, 0.0020) - ]) + ] + + @parameterized.expand(float16_lora_cases) def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): - if backend == BACKEND.BITBLAS and os.getenv("RUN_BITBLAS_TESTS", "0") != "1": - self.skipTest("BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)") + self._maybe_skip_backend(backend) data = self.data[dtype] out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter) @@ -270,21 +307,18 @@ def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_ reference_label="Torch with Lora output", ) - - @parameterized.expand([ + bfloat16_lora_cases = [ (BACKEND.TORCH, torch.bfloat16, 0.0000), - # (BACKEND.TORCH_FUSED, torch.bfloat16, 0.0001), (BACKEND.TRITON, torch.bfloat16, 0.00001), - # (BACKEND.EXLLAMA_V1, torch.bfloat16, 0.0062), (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0059), - (BACKEND.MARLIN, torch.bfloat16, 0.0033), + (BACKEND.MACHETE, torch.bfloat16, 0.0033), + (BACKEND.MARLIN, torch.bfloat16, 0.0050), (BACKEND.BITBLAS, torch.bfloat16, 0.0033), - # (BACKEND.MARLIN_FP16, torch.bfloat16, 0.011), - # (BACKEND.EXLLAMA_EORA, torch.bfloat16, 0.0014) TODO FIX, abnormal output when Exllama Eora kernel is using bfloat16 - ]) + ] + + @parameterized.expand(bfloat16_lora_cases) def test_kernel_bfloat16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): - if backend == BACKEND.BITBLAS and os.getenv("RUN_BITBLAS_TESTS", "0") != "1": - self.skipTest("BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)") + self._maybe_skip_backend(backend) data = self.data[dtype] out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter)