From 9f44ae005bc7c3e2846165ff1852f292649fa269 Mon Sep 17 00:00:00 2001 From: "Ayoub G." Date: Mon, 9 Mar 2026 02:31:03 -0600 Subject: [PATCH 1/2] adding INT8 MMA support on SM80+ --- .../python/CuTeDSL/ampere/tensorop_gemm_i8.py | 838 ++++++++++++++++++ .../cutlass/cute/nvgpu/warp/__init__.py | 2 + python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py | 102 +++ 3 files changed, 942 insertions(+) create mode 100644 examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py diff --git a/examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py b/examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py new file mode 100644 index 0000000000..d531fabdde --- /dev/null +++ b/examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py @@ -0,0 +1,838 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import math +from typing import Tuple, Type + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +from cutlass.cute.runtime import from_dlpack + +""" +An INT8 dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE DSL. +- Matrix A is MxKxL (INT8/UINT8), L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL (INT8/UINT8), L is batch dimension, B can be row-major("K") or column-major("N") +- Matrix C is MxNxL (INT32), L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Ampere's INT8 tensor cores (mma.sync m16n8k16 or m16n8k32) + - Supports signedness: S8xS8 (S8xU8 supported by hardware but blocked by + DSL compiler bug, U8xS8 not supported by hardware) + - Multi-stage pipeline (3 stages) to overlap computation and memory access + - cp.async for global-to-shared memory transfers + - Swizzled shared memory layout for bank-conflict-free access + - Supports arbitrary M, N, K dimensions (padded to tile boundaries) + +To run this example: + +.. 
code-block:: bash + + python examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py \\ + --mnkl 512,512,512,1 --atom_layout_mnk 2,2,1 \\ + --a_dtype Int8 --b_dtype Int8 --c_dtype Int32 --acc_dtype Int32 \\ + --a_major k --b_major k --c_major n + +Constraints: +* Supported input types: Int8, Uint8 +* Supported accumulator/output type: Int32 +* Default tile shape: 128x128x64 (BM auto-selected based on M) +* A and B must be K-major (row-major): ldmatrix requires 128-bit aligned smem + addresses, but column-major INT8 only provides 64-bit alignment +* Output C can be N-major (row-major) or M-major (column-major) +* The contiguous dimension must be at least 16 bytes aligned +""" + + +class TensorOpGemmI8: + def __init__( + self, + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + atom_layout_mnk: Tuple[int, int, int], + use_k32: bool = False, + bm: int = 128, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.c_dtype = c_dtype + self.acc_dtype = acc_dtype + self.use_k32 = use_k32 + self.cta_tiler = (bm, 128, 64) + self.num_stages = 3 + self.atom_layout_mnk = atom_layout_mnk + atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk + self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32 + + self.bM, self.bN, self.bK = self.cta_tiler + self.mma_inst_shape = (16, 8, 32) if use_k32 else (16, 8, 16) + mmaM, mmaN, mmaK = self.mma_inst_shape + + # PTX mma.sync constraint: if mixed signedness, A must be signed, B unsigned. + # A=Uint8, B=Int8 is NOT supported by the hardware. + if not self.a_dtype.signed and self.b_dtype.signed: + raise ValueError( + "A=Uint8, B=Int8 is not supported by the MMA instruction. " + "Use A=Int8, B=Uint8 or same signedness for both." 
+ ) + + assert self.bM % (atom_lay_M * mmaM) == 0, ( + "bM must be divisible by MMA instruction" + ) + assert self.bN % (atom_lay_N * mmaN) == 0, ( + "bN must be divisible by MMA instruction" + ) + assert atom_lay_K == 1, "this example does not support atom layout K > 1" + assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction" + assert self.num_stages >= 3, "num_stages must be greater than or equal to 3" + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + self.a_major_mode = utils.LayoutEnum.from_tensor(mA) + self.b_major_mode = utils.LayoutEnum.from_tensor(mB) + self.c_major_mode = utils.LayoutEnum.from_tensor(mC) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout + # /////////////////////////////////////////////////////////////////////////////// + ab_copy_bits = 128 + sA_layout = self._make_smem_layout_AB( + mA.element_type, + self.a_major_mode, + ab_copy_bits, + (self.cta_tiler[0], self.cta_tiler[2], self.num_stages), + ) + sB_layout = self._make_smem_layout_AB( + mB.element_type, + self.b_major_mode, + ab_copy_bits, + (self.cta_tiler[1], self.cta_tiler[2], self.num_stages), + ) + + smem_size = ( + cute.size_in_bytes(mA.element_type, sA_layout) + + cute.size_in_bytes(mB.element_type, sB_layout) + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled copy: global memory -> shared memory (asynchronous cp.async) + # /////////////////////////////////////////////////////////////////////////////// + atom_async_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp( + cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL + ), + mA.element_type, + num_bits_per_copy=ab_copy_bits, + ) + + tiled_copy_A = self._make_gmem_tiled_copy_AB( + atom_async_copy, mA.element_type, self.a_major_mode, ab_copy_bits + ) + tiled_copy_B = self._make_gmem_tiled_copy_AB( + atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled MMA + # /////////////////////////////////////////////////////////////////////////////// + op = cute.nvgpu.warp.MmaI8Op( + self.a_dtype, self.b_dtype, self.acc_dtype, self.mma_inst_shape + ) + + permutation_mnk = ( + self.atom_layout_mnk[0] * self.mma_inst_shape[0], + self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2, + self.atom_layout_mnk[2] * self.mma_inst_shape[2], + ) + + tC = cute.make_layout(self.atom_layout_mnk) + tiled_mma = cute.make_tiled_mma( + op, + tC, + permutation_mnk=permutation_mnk, + ) + + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + + raster_factor = 1 + grid_dim_n = cute.size(grid_dim[1]) + if grid_dim_n > 5: + raster_factor = 8 + elif grid_dim_n > 2: + raster_factor = 4 + elif grid_dim_n > 1: + raster_factor = 2 + rasterization_remap_grid_dim = ( + cute.size(grid_dim[0]) * raster_factor, + (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor, + cute.size(grid_dim[2]), + ) + + self.kernel( + mA, + mB, + mC, + sA_layout, + sB_layout, + tiled_copy_A, + tiled_copy_B, + tiled_mma, + raster_factor, + epilogue_op, + ).launch( + grid=rasterization_remap_grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + ) + + @cute.kernel + def kernel( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sA_layout: cute.ComposedLayout, + sB_layout: cute.ComposedLayout, + tiled_copy_A: cute.TiledCopy, + tiled_copy_B: 
cute.TiledCopy, + tiled_mma: cute.TiledMma, + rasterization_factor: cutlass.Int32, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + offset_tile_x, offset_tile_y = self.raster_tile( + bidx, bidy, rasterization_factor + ) + # Early exit if CTA is out of range + if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y: + pass + else: + tiler_coord = (offset_tile_x, offset_tile_y, None) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) + # /////////////////////////////////////////////////////////////////////////////// + gA = cute.local_tile( + mA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + gB = cute.local_tile( + mB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + gC = cute.local_tile( + mC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + + # Make the first k-tiles irregular instead of last, so we handle + # the residual tile first and avoid a condition in the mainloop. + residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size( + gA, mode=[2] + ) + + gA = cute.domain_offset((0, residual_k, 0), gA) + gB = cute.domain_offset((0, residual_k, 0), gB) + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + + # Construct identity layout for predication + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile( + mcA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + cB = cute.local_tile( + mcB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + + cA = cute.domain_offset((0, residual_k, 0), cA) + cB = cute.domain_offset((0, residual_k, 0), cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Create shared memory buffers and thread partitions. + # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) + # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) + # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + + sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) + sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + tAgA = thr_copy_A.partition_S(gA) + tAsA = thr_copy_A.partition_D(sA) + tBgB = thr_copy_B.partition_S(gB) + tBsB = thr_copy_B.partition_D(sB) + + tAcA = thr_copy_A.partition_S(cA) + tBcB = thr_copy_B.partition_S(cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predication: mark indices that need copies when M/N/K aren't + # multiples of tile shape. M/N predicates are stored in tensors; + # K residual is handled by the domain_offset + if/else branch. 
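+            # Illustrative example (run() pads K, but the kernel is general):
+            # K = 100 with bK = 64 gives two k-tiles and
+            # residual_k = 100 - 2 * 64 = -28, so the first tile spans
+            # k in [-28, 36) and the negative k coordinates are masked off
+            # by the elem_less checks in the prologue below.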
+ # /////////////////////////////////////////////////////////////////////////////// + tApA = cute.make_rmem_tensor( + cute.make_layout( + ( + tAgA.shape[0][1], + cute.size(tAgA, mode=[1]), + cute.size(tAgA, mode=[2]), + ), + stride=(cute.size(tAgA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + tBpB = cute.make_rmem_tensor( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=(cute.size(tBsB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] + ) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cute.elem_less( + tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch prologue + # /////////////////////////////////////////////////////////////////////////////// + tAsA.fill(0) + tBsB.fill(0) + cute.arch.sync_threads() + num_smem_stages = cute.size(tAsA, mode=[3]) + k_tile_count = cute.size(tAgA, mode=[3]) + k_tile_index = cutlass.Int32(0) + + for k in range(tApA.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]): + cute.copy( + tiled_copy_A, + tAgA[None, None, k, k_tile_index], + tAsA[None, None, k, 0], + pred=tApA[None, None, k], + ) + for k in range(tBpB.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]): + cute.copy( + tiled_copy_B, + tBgB[None, None, k, k_tile_index], + tBsB[None, None, k, 0], + pred=tBpB[None, None, k], + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + for k_tile in range(1, num_smem_stages - 1): + if k_tile == k_tile_count: + tApA.fill(0) + tBpB.fill(0) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, k_tile], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, k_tile], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + tCrC.fill(0) + + # /////////////////////////////////////////////////////////////////////////////// + # Copy Atom A/B retiling (shared memory -> registers) + # /////////////////////////////////////////////////////////////////////////////// + # S2R copy atoms for INT8 + # Use LdMatrix8x8x16bOp: ldmatrix operates in 16-bit register + # units regardless of element type; each 16b value holds 2 INT8 + # elements. The transpose depends on the majorness of the operand. 
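+            # (Since ldmatrix moves 16-bit units, each 8x8 tile it loads spans
+            # an 8x16 block of INT8 elements.)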
+ atom_copy_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mA.element_type, + ) + atom_copy_s2r_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mB.element_type, + ) + + tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma) + tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma) + + thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) + thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + + smem_pipe_read = 0 + smem_pipe_write = num_smem_stages - 1 + + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + + # Register pipeline prefetch + num_k_block = cute.size(tCrA, mode=[2]) + if num_k_block > 1: + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, 0], + tCrA_copy_view[None, None, 0], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, 0], + tCrB_copy_view[None, None, 0], + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop: interleaved smem pipeline (gmem->smem) and register + # pipeline (smem->rmem) with MMA computation. + # /////////////////////////////////////////////////////////////////////////////// + for k_tile in range(k_tile_count): + for k_block in cutlass.range(num_k_block, unroll_full=True): + if k_block == num_k_block - 1: + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + + k_block_next = (k_block + 1) % num_k_block + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next], + ) + + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, smem_pipe_write], + pred=tApA, + ) + + cute.gemm( + tiled_mma, + tCrC, + tCrA[None, None, k_block], + tCrB[None, None, k_block], + tCrC, + ) + + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, smem_pipe_write], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + smem_pipe_write = smem_pipe_read + smem_pipe_read = smem_pipe_read + 1 + if smem_pipe_read == num_smem_stages: + smem_pipe_read = 0 + + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue: direct register to global memory (no smem buffer). + # Unlike the FP16 tensorop_gemm which buffers through smem for + # predicated stores, we skip smem to avoid a 64KB INT32 buffer + # (128x128x4B) that would limit occupancy to 1 CTA/SM. Instead, + # the run() function pads output tensors to tile boundaries. 
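+            # For reference: the 3-stage A/B staging already uses
+            # 3 * (128 + 128) * 64 B = 48 KB of shared memory, so an extra
+            # 64 KB C buffer would push a single CTA past 100 KB.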
+ # /////////////////////////////////////////////////////////////////////////////// + tCrD = cute.make_fragment_like(tCrC, self.c_dtype) + tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype) + cute.autovec_copy(tCrD, tCgC) + return + + def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler): + major_mode_size = ( + smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0] + ) + major_mode_size = 64 if major_mode_size >= 64 else major_mode_size + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + # MBase: log2 of elements per 128-bit copy + # INT8: 128/8 = 16 elems -> MBase=4, F16: 128/16 = 8 elems -> MBase=3 + mbase = int(math.log2(copy_bits // dtype.width)) + # SShift must differ from MBase for actual swizzling (SShift==MBase is no-op) + # Swizzle<2,4,3> is correct for INT8: 16-byte contiguity, 8-row atom + sshift = 3 + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, mbase, sshift), + 0, + layout_atom_outer, + ) + layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2)) + return layout + + def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits): + copy_elems = copy_bits // dtype.width + shape_dim_1 = cute.size(self.bK) // copy_elems + thread_layout = cute.make_layout( + (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != utils.LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.bM) // copy_elems + thread_layout = cute.make_layout( + (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + def raster_tile(self, i, j, f): + new_i = i // f + new_j = (i % f) + (j * f) + return (new_i, new_j) + + +def run( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + mnkl: Tuple[int, int, int, int], + atom_layout_mnk: Tuple[int, int, int] = None, + use_k32: bool = False, + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", + warmup_iterations: int = 2, + iterations: int = 100, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + import torch + import cutlass.torch as cutlass_torch + + M, N, K, L = mnkl + + # Auto-select tile size and atom layout based on M + # (2,4,1)=8 warps best for M>=64 (fewer accum regs) + # (2,2,1)=4 warps best for large M (2 CTAs/SM, better latency hiding) + # (2,4,1) requires bm>=64 for correct gmem copy thread layout + if M <= 16: + bm, default_atom = 16, (1, 2, 1) + elif M <= 32: + bm, default_atom = 32, (2, 2, 1) + elif M <= 64: + bm, default_atom = 64, (2, 4, 1) + elif M <= 256: + bm, default_atom = 128, (2, 4, 1) + else: + bm, default_atom = 128, (2, 2, 1) + if atom_layout_mnk is None: + atom_layout_mnk = default_atom + + # INT8 requires K-major (row-major) inputs: ldmatrix needs 128-bit aligned + # smem addresses, but column-major INT8 only gives 64-bit alignment. + if a_major != "k": + raise ValueError( + f"A must be K-major (row-major) for INT8. Got a_major='{a_major}'. 
" + "Column-major INT8 does not meet ldmatrix 128-bit alignment." + ) + if b_major != "k": + raise ValueError( + f"B must be K-major (row-major) for INT8. Got b_major='{b_major}'. " + "Column-major INT8 does not meet ldmatrix 128-bit alignment." + ) + + # Pad dimensions to tile boundaries for the unpredicated epilogue. + # Input loads use predicates, but epilogue stores don't check bounds. + bN = 128 + bK = 64 + M_pad = ((M + bm - 1) // bm) * bm + N_pad = ((N + bN - 1) // bN) * bN + K_pad = ((K + bK - 1) // bK) * bK + + print("Running Ampere INT8 tensor core GEMM example:") + print(f"mnkl: {mnkl}, bm: {bm}, padded: ({M_pad}, {N_pad}, {K_pad})") + print( + f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Use k32 MMA: {use_k32}") + print(f"Atoms layout: {atom_layout_mnk}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + + # A(M,K) K-major, B(N,K) K-major, C(M,N) N-major + def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype): + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + torch_dtype = cutlass_torch.dtype(dtype) + if dtype.signed: + torch_tensor = torch.randint(-2, 3, shape, dtype=torch_dtype) + else: + torch_tensor = torch.randint(0, 5, shape, dtype=torch_dtype) + torch_tensor = torch_tensor.permute(permute_order).cuda() + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0)) + .mark_compact_shape_dynamic( + mode=(1 if not is_mode0_major else 0), + stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0), + divisibility=(128 // dtype.width), + ) + ) + return cute_tensor, torch_tensor + + # Create padded tensors; zero padding ensures MMA gets zeros for OOB elements. + # Tensor shape after permute is always (mode0, mode1, L); majorness only + # affects strides, so padding indexing is the same for any majorness. 
+ mA, a_torch = create_and_permute_tensor(L, M_pad, K_pad, a_major == "m", a_dtype) + mB, b_torch = create_and_permute_tensor(L, N_pad, K_pad, b_major == "n", b_dtype) + mC, c_torch = create_and_permute_tensor(L, M_pad, N_pad, c_major == "m", c_dtype) + # Zero the padding regions of input tensors + if M_pad > M: + a_torch[M:, :, :] = 0 + if K_pad > K: + a_torch[:, K:, :] = 0 + b_torch[:, K:, :] = 0 + if N_pad > N: + b_torch[N:, :, :] = 0 + + tensor_op_gemm = TensorOpGemmI8( + a_dtype, + b_dtype, + c_dtype, + acc_dtype, + atom_layout_mnk, + use_k32, + bm, + ) + + print("Compiling kernel with cute.compile ...") + compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC) + + print("Executing GEMM kernel...") + + if not skip_ref_check: + # torch.einsum doesn't support int on CUDA, compute on CPU + # Use only the valid (non-padded) region for reference + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch[:M, :K, :].cpu().to(dtype=torch.int32), + b_torch[:N, :K, :].cpu().to(dtype=torch.int32), + ).to(cutlass_torch.dtype(c_dtype)) + compiled_gemm(mA, mB, mC) + print("Verifying results...") + torch.testing.assert_close(c_torch[:M, :N, :].cpu(), ref, atol=0, rtol=0) + print("Results verified successfully!") + + def generate_tensors(): + a_workspace, _ = create_and_permute_tensor(L, M_pad, K_pad, a_major == "m", a_dtype) + b_workspace, _ = create_and_permute_tensor(L, N_pad, K_pad, b_major == "n", b_dtype) + c_workspace, _ = create_and_permute_tensor(L, M_pad, N_pad, c_major == "m", c_dtype) + return testing.JitArguments(a_workspace, b_workspace, c_workspace) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + warmup_iterations=warmup_iterations, + iterations=iterations, + use_cuda_graphs=False, + ) + + return avg_time_us + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." 
+ ) + + parser = argparse.ArgumentParser( + description="INT8 tensor core GEMM with CuTe DSL on Ampere GPU" + ) + parser.add_argument( + "--mnkl", type=parse_comma_separated_ints, default=(128, 128, 128, 1) + ) + parser.add_argument( + "--atom_layout_mnk", type=parse_comma_separated_ints, default=None + ) + parser.add_argument( + "--a_dtype", + type=cutlass.dtype, + choices=[cutlass.Int8, cutlass.Uint8], + default=cutlass.Int8, + ) + parser.add_argument( + "--b_dtype", + type=cutlass.dtype, + choices=[cutlass.Int8, cutlass.Uint8], + default=cutlass.Int8, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + choices=[cutlass.Int32], + default=cutlass.Int32, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + choices=[cutlass.Int32], + default=cutlass.Int32, + ) + parser.add_argument("--use_k32", action="store_true", default=False) + parser.add_argument("--a_major", choices=["k", "m"], default="k") + parser.add_argument("--b_major", choices=["k", "n"], default="k") + parser.add_argument("--c_major", choices=["n", "m"], default="n") + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + run( + args.a_dtype, + args.b_dtype, + args.c_dtype, + args.acc_dtype, + args.mnkl, + args.atom_layout_mnk, + args.use_k32, + args.a_major, + args.b_major, + args.c_major, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py index 6d1e30344b..afdbcc83dc 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py @@ -18,6 +18,8 @@ # mma.py "Field", "MmaF16BF16Op", + "MmaI8Op", + "MmaIntOverflow", "MmaMXF4Op", "MmaMXF4NVF4Op", # copy.py diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py index 781128b640..134ad34192 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -27,6 +27,9 @@ Float16, BFloat16, Float32, + Int8, + Uint8, + Int32, Boolean, Numeric, Pointer, @@ -118,6 +121,105 @@ class MmaF16BF16Trait(Trait): pass +class MmaIntOverflow(enum.Enum): + """Integer overflow mode for warp-level INT8 MMA operations.""" + + SATURATE = "satfinite" + WRAP = "wrapped" + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + +@dataclass(frozen=True) +class MmaI8Op(WarpMmaOp): + """ + INT8 warp-level MMA Operation. + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.s8`` or ``.u8`` qualifiers for the input operands, + with ``.s32`` accumulator. + + Supported shapes: (16,8,16) and (16,8,32). + Supports mixed signedness (e.g. signed A x unsigned B). 
+ """ + + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + overflow: MmaIntOverflow = MmaIntOverflow.SATURATE + + def __post_init__(self) -> None: + if self.a_dtype not in [Int8, Uint8]: + raise OpError( + self, + "expects the 'a_dtype' Op parameter to be one of Int8 or Uint8", + ) + if self.b_dtype not in [Int8, Uint8]: + raise OpError( + self, + "expects the 'b_dtype' Op parameter to be one of Int8 or Uint8", + ) + if self.acc_dtype != Int32: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Int32", + ) + if self.shape_mnk not in [(16, 8, 16), (16, 8, 32)]: + raise OpError( + self, + "expects the 'shape_mnk' Op parameter to be one of (16,8,16) or (16,8,32)", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + overflow_attr = ir.Attribute.parse( + f"#cute_nvgpu.mma_int_overflow<{self.overflow.value}>" + ) + a_mlir_type = ( + ir.IntegerType.get_signed(8) + if self.a_dtype.signed + else ir.IntegerType.get_unsigned(8) + ) + b_mlir_type = ( + ir.IntegerType.get_signed(8) + if self.b_dtype.signed + else ir.IntegerType.get_unsigned(8) + ) + ty = _cute_nvgpu_ir.MmaAtomSM80Type.get( + shape_mnk.type.attribute, + a_mlir_type, + b_mlir_type, + self.acc_dtype.mlir_type, + intOverflow=overflow_attr, + ) + return MmaI8Trait(make_atom(ty, loc=loc, ip=ip)) + + def __str__(self) -> str: + return ( + "warp-level INT8 MMA Operation" + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + + f"\n Overflow mode = {self.overflow}" + ) + + def _verify_fragment_A(self, input: Tensor, *, loc=None, ip=None): + pass + + def _verify_fragment_B(self, input: Tensor, *, loc=None, ip=None): + pass + + +class MmaI8Trait(Trait): + pass + + # Base class for SM120 Blockscaled MMA Ops @dataclass(frozen=True) class MmaSM120BlockScaledOp(MmaOp): From 382200f18f8a1f906799dde38afa7ad2840dcea5 Mon Sep 17 00:00:00 2001 From: "Ayoub G." 
Date: Mon, 9 Mar 2026 10:20:09 -0600 Subject: [PATCH 2/2] adding int gemm to cute examples --- examples/cute/tutorial/igemm_sm80.cu | 532 +++++++++++++++++++++++++++ 1 file changed, 532 insertions(+) create mode 100644 examples/cute/tutorial/igemm_sm80.cu diff --git a/examples/cute/tutorial/igemm_sm80.cu b/examples/cute/tutorial/igemm_sm80.cu new file mode 100644 index 0000000000..722594c186 --- /dev/null +++ b/examples/cute/tutorial/igemm_sm80.cu @@ -0,0 +1,532 @@ +/*************************************************************************************************** + * INT8 Tensor Core GEMM using CuTe on SM80+ + * + * C(M,N) = A(M,K) * B(N,K)^T with A, B K-major (row-major), C M-major (col-major) + * Accumulator: int32, Output: int32 + * + * Features: + * - SM80 m16n8k16 / m16n8k32 INT8 MMA atoms + * - Swizzled shared memory (Swizzle<2,4,3>) for zero bank conflicts + * - 3-stage cp.async pipeline with interleaved A/B loads + * - CTA swizzle for L2 locality + * - Predicated loads (ZFILL) for arbitrary M, N, K + * - Configurable tile sizes (BM, BN) and MMA K-size (k16/k32) + * + **************************************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +using namespace cute; + +// ─── Shared memory storage ──────────────────────────────────────────────── +template +struct SharedStorage { + ArrayEngine> A; + ArrayEngine> B; +}; + +// ─── Predicate helper ───────────────────────────────────────────────────── +// Build a (CPY_M, CPY_K) bool tensor from a partitioned coordinate tensor +template +__device__ auto make_pred_2d(CoordTensor const& coords, int row_max, int k_lim) { + auto pred = make_tensor(make_shape(size<1>(coords), size<2>(coords))); + CUTE_UNROLL + for (int i = 0; i < size<0>(pred); ++i) { + CUTE_UNROLL + for (int j = 0; j < size<1>(pred); ++j) { + pred(i,j) = get<0>(coords(Int<0>{},i,j)) < row_max && + get<1>(coords(Int<0>{},i,j)) < k_lim; + } + } + return pred; +} + +// ─── GEMM kernel ────────────────────────────────────────────────────────── +template +__global__ static +__launch_bounds__(decltype(size(TiledMma{}))::value) +void gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, S2RAtomA s2r_atom_a, + TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, S2RAtomB s2r_atom_b, + TC * C, CStride dC, CSmemLayout , TiledMma mma, + Alpha alpha, Beta beta) +{ + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); + + // Full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K) + Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K) + Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N) + + // CTA swizzle: remap blockIdx for L2 locality + int tile_m = blockIdx.x, tile_n = blockIdx.y; + if constexpr (SwizzleLog > 0) { + constexpr int SW = 1 << SwizzleLog; + int tiles_m = gridDim.x; + int tiles_n = gridDim.y; + if (tiles_n >= SW) { + int bid = blockIdx.x + blockIdx.y * tiles_m; + int panel_sz = tiles_m * SW; + int panel_id = bid / panel_sz; + int within = bid % panel_sz; + tile_m = within / SW; + tile_n = panel_id * SW + within % SW; + if (tile_n >= tiles_n) return; + } + } + auto cta_coord = make_coord(tile_m, tile_n, _); + Tensor 
gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Coordinate tensors for bounds-checking + Tensor cA = make_identity_tensor(make_shape(size<0>(cta_tiler), size<2>(cta_tiler))); + Tensor cB = make_identity_tensor(make_shape(size<1>(cta_tiler), size<2>(cta_tiler))); + Tensor cC = make_identity_tensor(make_shape(size<0>(cta_tiler), size<1>(cta_tiler))); + + int m_max = get<0>(shape_MNK) - tile_m * int(size<0>(cta_tiler)); + int n_max = get<1>(shape_MNK) - tile_n * int(size<1>(cta_tiler)); + int k_max = get<2>(shape_MNK); + int bk = int(size<2>(cta_tiler)); + + // Shared memory buffers + extern __shared__ char shared_memory[]; + using SmemStorage = SharedStorage; + SmemStorage& smem = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE) + + // + // Partition the copying of A and B tiles across the threads + // + + ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x); + Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k) + Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tAcA = thr_copy_a.partition_S(cA); // (CPY,CPY_M,CPY_K) + + ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x); + Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k) + Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tBcB = thr_copy_b.partition_S(cB); // (CPY,CPY_N,CPY_K) + + // + // Define A/B partitioning and C accumulators + // + + ThrMMA thr_mma = mma.get_slice(threadIdx.x); + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + clear(tCrC); + Tensor tCcC = thr_mma.partition_C(cC); // epilogue coords + + // + // Copy Atom retiling + // + + TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma); + ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x); + Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE) + Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K) + + TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma); + ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x); + Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE) + Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K) + + // + // PREFETCH + // + + auto K_PIPE_MAX = size<3>(tAsA); + int k_tile_count = size<3>(tAgA); + int k_tile_next = 0; + + // Interior CTA: all M/N/K elements in-bounds -> skip predication + bool is_interior = (m_max >= int(size<0>(cta_tiler))) && + (n_max >= int(size<1>(cta_tiler))) && + (k_max % bk == 0); + + // Start async loads for all pipes but the last + CUTE_UNROLL + for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) { + if (is_interior) { + copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe)); + copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe)); + } else { + int k_lim = k_max - k_tile_next * bk; + copy_if(copy_a, make_pred_2d(tAcA, m_max, k_lim), tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe)); + copy_if(copy_b, make_pred_2d(tBcB, n_max, k_lim), tBgB(_,_,_,k_tile_next), 
tBsB(_,_,_,k_pipe)); + } + cp_async_fence(); + --k_tile_count; + if (k_tile_count > 0) { ++k_tile_next; } + } + + // + // PIPELINED MAIN LOOP + // gmem(k_tile_next) -> smem(smem_pipe_write) + // smem(smem_pipe_read) -> rmem(k_block_next) + // compute on rmem(k_block) + // + + auto K_BLOCK_MAX = size<2>(tCrA); + int smem_pipe_read = 0; + int smem_pipe_write = K_PIPE_MAX-1; + + Tensor tXsA_p = tXsA(_,_,_,smem_pipe_read); + Tensor tXsB_p = tXsB(_,_,_,smem_pipe_read); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + cp_async_wait(); + __syncthreads(); + copy(s2r_atom_a, tXsA_p(_,_,Int<0>{}), tXrA(_,_,Int<0>{})); + copy(s2r_atom_b, tXsB_p(_,_,Int<0>{}), tXrB(_,_,Int<0>{})); + } + + CUTE_NO_UNROLL + while (k_tile_count > -(K_PIPE_MAX-1)) + { + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tXsA_p = tXsA(_,_,_,smem_pipe_read); + tXsB_p = tXsB(_,_,_,smem_pipe_read); + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; + copy(s2r_atom_a, tXsA_p(_,_,k_block_next), tXrA(_,_,k_block_next)); + copy(s2r_atom_b, tXsB_p(_,_,k_block_next), tXrB(_,_,k_block_next)); + + // Interleaved copy: A before gemm, B after gemm + if (k_block == 0) + { + if (is_interior) { + copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,smem_pipe_write)); + } else { + int k_lim = k_max - k_tile_next * bk; + copy_if(copy_a, make_pred_2d(tAcA, m_max, k_lim), tAgA(_,_,_,k_tile_next), tAsA(_,_,_,smem_pipe_write)); + } + } + + // Thread-level register gemm for k_block + gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + + if (k_block == 0) + { + if (is_interior) { + copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,smem_pipe_write)); + } else { + int k_lim = k_max - k_tile_next * bk; + copy_if(copy_b, make_pred_2d(tBcB, n_max, k_lim), tBgB(_,_,_,k_tile_next), tBsB(_,_,_,smem_pipe_write)); + } + cp_async_fence(); + --k_tile_count; + if (k_tile_count > 0) { ++k_tile_next; } + smem_pipe_write = smem_pipe_read; + smem_pipe_read = (smem_pipe_read == K_PIPE_MAX-1) ? 
0 : smem_pipe_read+1; + } + } + } + + // + // Epilogue + // + + if (is_interior) { + axpby(alpha, tCrC, beta, tCgC); + } else { + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) { + if (get<0>(tCcC(i)) < m_max && get<1>(tCcC(i)) < n_max) { + tCgC(i) = alpha * tCrC(i) + beta * tCgC(i); + } + } + } +} + +// ─── Launch wrapper ─────────────────────────────────────────────────────── +// BM: compile-time M tile size (must be >= 16, multiple of 16) +// BN: compile-time N tile size (must be >= 64, multiple of 64) +// K32: use m16n8k32 MMA atom (true) vs m16n8k16 (false) +template +void gemm_i8(int m, int n, int k, + int32_t alpha, + int8_t const* A, int ldA, + int8_t const* B, int ldB, + int32_t beta, + int32_t* C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + static_assert(BM >= 16 && BM % 16 == 0, "BM must be >= 16 and multiple of 16"); + static_assert(BN >= 64 && BN % 64 == 0, "BN must be >= 64 and multiple of 64"); + + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) K-major + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) K-major + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) M-major + + // CTA tile + auto bM = Int{}; + auto bN = Int{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline stages + + // Swizzled smem: Swizzle<2,4,3>, 8-row atom x 64 cols + auto smem_atom = composition(Swizzle<2, 4, 3>{}, + Layout, + Stride<_64, _1>>{}); + auto sA = tile_to_shape(smem_atom, make_shape(bM, bK, bP)); + auto sB = tile_to_shape(smem_atom, make_shape(bN, bK, bP)); + auto sC = make_layout(make_shape(bM, bN)); + + // Warp layout: always 8 warps (256 threads) for better occupancy + constexpr int WM = BM >= 32 ? 2 : 1; + constexpr int WN = 8 / WM; + constexpr int NUM_THREADS = WM * WN * 32; + + // A copy: adapt thread layout and vectorization to BM and thread count + constexpr int thr_rows_A = BM < 64 ? 
BM : 64; + constexpr int thr_kcols_A = NUM_THREADS / thr_rows_A; + constexpr int vals_A = 64 / thr_kcols_A; + + using CpType_A = conditional_t= 16, uint128_t, + conditional_t= 8, uint64_t, uint32_t>>; + + TiledCopy copyA = make_tiled_copy( + Copy_Atom, int8_t>{}, + Layout, Int>, + Stride, _1>>{}, // Thr layout + Layout>>{}); // Val layout + + // B copy: BNx64, 128-bit cp.async + constexpr int thr_rows_B = NUM_THREADS / 4; + TiledCopy copyB = make_tiled_copy( + Copy_Atom, int8_t>{}, + Layout, _4>, Stride<_4, _1>>{}, // Thr layout + Layout>{}); // Val layout + + // MMA atom: k16 vs k32 + using MmaAtom = conditional_t; + using S2RCopyA = conditional_t; + using S2RCopyB = conditional_t; + + TiledMMA mmaC = make_tiled_mma(MmaAtom{}, + Layout, Int, _1>>{}); + + Copy_Atom s2r_atom_A; + Copy_Atom s2r_atom_B; + + int smem_size = int(sizeof(SharedStorage)); + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + + constexpr int kSwizzle = 3; // 2^3=8 N-tiles per L2 panel + auto kernel_fptr = gemm_device; + + cudaFuncSetAttribute(kernel_fptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + cudaFuncSetAttribute(kernel_fptr, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + kernel_fptr<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, s2r_atom_A, + B, dB, sB, copyB, s2r_atom_B, + C, dC, sC, mmaC, + alpha, beta); +} + +// ─── CPU reference ──────────────────────────────────────────────────────── +// C(M,N) = A(M,K) * B(N,K)^T, A K-major, B K-major, C N-major (row-major) +void cpu_i8gemm(int m, int n, int k, + int8_t const* A, int8_t const* B, int32_t* C) +{ + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + int32_t acc = 0; + for (int p = 0; p < k; ++p) { + acc += int32_t(A[i * k + p]) * int32_t(B[j * k + p]); + } + C[i * n + j] = acc; + } + } +} + +// ─── Main ───────────────────────────────────────────────────────────────── +int main(int argc, char** argv) +{ + cudaDeviceProp props; + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " + << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 8) { + std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl; + return 0; + } + + std::cout << "Using device 0: " << props.name + << " (SM" << props.major * 10 + props.minor + << ", " << props.multiProcessorCount << " SMs)" << std::endl; + + int m = 4096; + if (argc >= 2) sscanf(argv[1], "%d", &m); + int n = 4096; + if (argc >= 3) sscanf(argv[2], "%d", &n); + int k = 4096; + if (argc >= 4) sscanf(argv[3], "%d", &k); + + std::cout << "M = " << m << std::endl; + std::cout << "N = " << n << std::endl; + std::cout << "K = " << k << std::endl; + + int ldA = k, ldB = k, ldC = n; + + thrust::host_vector h_A(m * k), h_B(n * k); + thrust::host_vector h_C(m * n, 0); + + srand(42); + for (int i = 0; i < m * k; ++i) h_A[i] = (int8_t)((rand() % 7) - 3); + for (int i = 0; i < n * k; ++i) h_B[i] = (int8_t)((rand() % 7) - 3); + + thrust::device_vector d_A = h_A, d_B = h_B; + thrust::device_vector d_C(m * n, 0); + + // + // Correctness verification on a small problem + // + + { + int vm = 100, vn = 256, vk = 192; // non-aligned sizes + std::cout << "\nVerification: M=" << vm << " N=" << vn << " K=" << vk << std::endl; + + thrust::host_vector vA(vm * vk), vB(vn * vk); + srand(123); + for (int i = 0; i < vm * vk; ++i) vA[i] = (int8_t)((rand() % 7) - 3); + for (int i = 0; i < vn * vk; ++i) vB[i] = 
(int8_t)((rand() % 7) - 3); + + // CPU reference + thrust::host_vector h_ref(vm * vn, 0); + cpu_i8gemm(vm, vn, vk, vA.data(), vB.data(), h_ref.data()); + + // GPU kernel (swap A<->B and m<->n for N-major C output) + thrust::device_vector d_vA = vA, d_vB = vB; + thrust::device_vector d_vC(vm * vn, 0); + gemm_i8<128, 128, false>(vn, vm, vk, 1, + d_vB.data().get(), vk, + d_vA.data().get(), vk, + 0, d_vC.data().get(), vn); + CUTE_CHECK_LAST(); + cudaDeviceSynchronize(); + + thrust::host_vector h_gpu = d_vC; + int errs = 0; + for (int i = 0; i < vm * vn; ++i) { + if (h_gpu[i] != h_ref[i]) { + if (errs < 5) + std::cerr << " MISMATCH [" << i << "]: gpu=" << h_gpu[i] + << " cpu=" << h_ref[i] << std::endl; + ++errs; + } + } + std::cout << "Verification: " << (errs == 0 ? "PASS" : "FAIL") + << " (" << errs << " errors)" << std::endl; + if (errs > 0) return 1; + } + + // + // Timing + // + + double gops = 2.0 * m * n * k * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + // Warmup + gemm_i8<128, 128, false>(n, m, k, 1, + d_B.data().get(), ldB, + d_A.data().get(), ldA, + 0, d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + cudaDeviceSynchronize(); + + // k16 timing + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm_i8<128, 128, false>(n, m, k, 1, + d_B.data().get(), ldB, + d_A.data().get(), ldA, + 0, d_C.data().get(), ldC); + } + double k16_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + + // k32 timing + gemm_i8<128, 128, true>(n, m, k, 1, + d_B.data().get(), ldB, + d_A.data().get(), ldA, + 0, d_C.data().get(), ldC); + cudaDeviceSynchronize(); + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm_i8<128, 128, true>(n, m, k, 1, + d_B.data().get(), ldB, + d_A.data().get(), ldA, + 0, d_C.data().get(), ldC); + } + double k32_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + + printf("\nINT8_GEMM_k16: [%6.1f]GOp/s (%6.4f)ms\n", gops / k16_time, k16_time * 1000); + printf("INT8_GEMM_k32: [%6.1f]GOp/s (%6.4f)ms\n", gops / k32_time, k32_time * 1000); + + return 0; +}
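+
+// Build/run sketch (paths and flags are illustrative; assumes a CUTLASS
+// checkout at $CUTLASS_DIR and an SM80-class GPU, adjust -arch as needed):
+//   nvcc -std=c++17 -arch=sm_80 \
+//        -I$CUTLASS_DIR/include -I$CUTLASS_DIR/tools/util/include \
+//        examples/cute/tutorial/igemm_sm80.cu -o igemm_sm80
+//   ./igemm_sm80 4096 4096 4096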