From cde2e2ec1b3533a48a5f210d3bcfba51e71c6b04 Mon Sep 17 00:00:00 2001 From: "chenwei.gavin" Date: Tue, 30 Apr 2024 12:00:09 +0800 Subject: [PATCH] [feat]: Support weight only gemm with 2bit --- .../gemm/kernel/default_fpA_intB_traits.h | 2 + .../gemm/kernel/mixed_gemm_B_layout.h | 27 ++ .../threadblock/default_dq_mma_multistage.h | 5 +- .../gemm/threadblock/default_mma.h | 154 +++++++++++ .../gemm/threadblock/default_mma_bf16.h | 101 ++++++++ .../dq_mma_multistage_finegrained.h | 32 ++- .../interleaved_numeric_conversion.h | 239 ++++++++++++++++++ .../cutlass_kernels/cutlass_preprocessors.cpp | 107 ++++++++ .../cutlass_kernels/cutlass_preprocessors.h | 4 +- .../cutlass_kernels/cutlass_type_conversion.h | 2 + .../bf16_int2_gemm_fg_scalebias.cu | 31 +++ .../bf16_int2_gemm_fg_scaleonly.cu | 31 +++ .../fpA_intB_gemm/bf16_int2_gemm_per_col.cu | 31 +++ .../fp16_int2_gemm_fg_scalebias.cu | 29 +++ .../fp16_int2_gemm_fg_scaleonly.cu | 28 ++ .../fpA_intB_gemm/fp16_int2_gemm_per_col.cu | 28 ++ .../fpA_intB_gemm/fpA_intB_gemm_template.h | 9 +- .../python/generate_kernels.py | 3 + 18 files changed, 846 insertions(+), 17 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index ee084116a..ee1189c51 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -141,6 +141,8 @@ struct MixedGemmArchTraits::value #ifdef ENABLE_FP8 || cutlass::platform::is_same::value>::type +#else + >::type #endif > { diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index a1712431e..8ac7f0723 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -113,6 +113,24 @@ template using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; +template + struct LayoutDetailsB < TypeA, + uint2b_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + + template struct 
LayoutDetailsB= 90>::type> { @@ -131,6 +149,15 @@ struct LayoutDetailsB +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h index 17c634655..59d372e4f 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -230,8 +230,9 @@ struct DqMma::value, "Mma multistage must dequantize after ldsm"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element B must be uint8, uint4 or uint2"); static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h index ad6c7496e..d7f0736b8 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h @@ -124,6 +124,54 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage /// (stage>=3) @@ -232,6 +280,59 @@ struct DefaultMma=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 
128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; #ifdef ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage @@ -287,6 +388,59 @@ struct DefaultMma=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; #endif // fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 77af81005..876d23258 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -244,6 +244,54 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + template < /// Layout type for A matrix operand typename LayoutA, @@ -348,6 +396,59 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + } // namespace threadblock } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h index 2d34d43cb..abd690804 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -562,6 +562,7 @@ class DqMmaMultistage (-Base::kStages + 1);) { @@ -569,6 +570,8 @@ class DqMmaMultistage; + FragmentOperandB converted_frag_B_operand; // Computes a warp-level GEMM on data held in shared memory // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate CUTLASS_PRAGMA_UNROLL @@ -588,23 +591,26 @@ class DqMmaMultistagewarp_tile_iterator_B_.set_kgroup_index( (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(idx + 1) % 2]); ++this->warp_tile_iterator_B_; } + if (warp_tileB_k_compute_offset == 0) { + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[idx % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + using Converter + = cutlass::NumericArrayConverter; + converted_frag_B_operand = Converter::convert(converted_frag_B); + } - typename 
TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + idx++; + } - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, warp_tileB_k_compute_offset); diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680..d18f883f7 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @@ -440,6 +440,245 @@ struct FastInterleavedAndBiasedNumericArrayConverter } }; +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i2s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM0_MASK = 0x00030003; + static constexpr uint32_t BOTTOM1_MASK = 0x000c000c; + static constexpr uint32_t TOP0_MASK = 0x00300030; + static constexpr uint32_t TOP1_MASK = 0x00c000c0; + static constexpr uint32_t I2s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
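The `lop3.b32` sequence below computes `(i2s & MASK) | I2s_TO_F16s_MAGIC_NUM` in a single instruction per output register: `immLut = (0xf0 & 0xcc) | 0xaa = 0xEA` encodes the boolean function `(a & b) | c`. Each mask isolates one 2-bit field per 16-bit half-word, and the magic number `0x6400` (fp16 1024.0, whose mantissa ulp is exactly 1) embeds that field in the mantissa. A rough host-side equivalent of the extraction step is sketched here for reference; the helper name `extract_i2s_to_fp16_bits` is illustrative and not part of the patch.

```cpp
// Illustrative host-side equivalent of the lop3-based extraction below
// (reference only; the kernel performs this with inline PTX).
#include <array>
#include <cstdint>

std::array<uint32_t, 8> extract_i2s_to_fp16_bits(uint32_t i2s)
{
    constexpr uint32_t kMagic = 0x64006400u; // half2 {1024, 1024}
    constexpr uint32_t kMasks[4] = {
        0x00030003u,  // bits [1:0] of each half-word
        0x000c000cu,  // bits [3:2]
        0x00300030u,  // bits [5:4]
        0x00c000c0u}; // bits [7:6]

    std::array<uint32_t, 8> h{};
    uint32_t const top_i2s = i2s >> 8; // upper byte of each half-word
    for (int m = 0; m < 4; ++m)
    {
        h[m] = (i2s & kMasks[m]) | kMagic;     // lop3 with immLut 0xEA
        h[4 + m] = (top_i2s & kMasks[m]) | kMagic;
    }
    return h;
}
```

Each `h[m]` then holds two fp16 values of the form 1024 + u * 4^k, which the sub/fma corrections further down reduce to the signed weight value.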
+ const uint32_t top_i2s = i2s >> 8; + // Extract elt_01 - (i2s & 0x00020002) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i2s), "n"(BOTTOM0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i2s), "n"(BOTTOM1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(i2s), "n"(TOP0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(i2s), "n"(TOP1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // Extract elt_89 - (i2s & 0x00020002) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[4]) + : "r"(top_i2s), "n"(BOTTOM0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[5]) + : "r"(top_i2s), "n"(BOTTOM1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[6]) + : "r"(top_i2s), "n"(TOP0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[7]) + : "r"(top_i2s), "n"(TOP1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1026, 1026} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64026402; + // This is the half2 {1 / 4, 1 / 4} represented as an integer. + static constexpr uint32_t ONE_FOUR = 0x34003400; + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + static constexpr uint32_t ONE_SIXTY_FOUR = 0x24002400; + // This is the half2 {-72, -72} represented as an integer. + //static constexpr uint32_t NEG_72 = 0xd480d480; + static constexpr uint32_t NEG_258 = 0xdc08dc08; + static constexpr uint32_t NEG_66 = 0xd420d420; + static constexpr uint32_t NEG_18 = 0xcc80cc80; + + // Finally, we construct the output numbers. 
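After extraction, the half-word at field position k holds the fp16 value 1024 + u * 4^k, where u is the stored unsigned 2-bit value (the signed weight plus 2, added during preprocessing). The sub.f16x2 / fma.rn.f16x2 instructions that follow undo the scale and bias in one step each: field 0 subtracts 1026 (0x6402), while fields 1 to 3 multiply by 1/4, 1/16, 1/64 and subtract 258, 66, 18, so every lane ends up at u - 2. A quick host-side check of these constants is sketched below; `check_fp16_int2_constants` is an illustrative name, not part of the patch.

```cpp
// Host-side sanity check of the fp16 correction constants used below
// (illustrative only; the kernel applies them with sub.f16x2 / fma.rn.f16x2).
#include <cassert>

void check_fp16_int2_constants()
{
    const float scale[4] = {1.0f, 0.25f, 0.0625f, 0.015625f}; // 1, 1/4 (0x3400), 1/16 (0x2c00), 1/64 (0x2400)
    const float offset[4] = {1026.0f, 258.0f, 66.0f, 18.0f};  // 0x6402, and negated 0xdc08, 0xd420, 0xcc80
    for (int k = 0; k < 4; ++k)
    {
        for (int u = 0; u <= 3; ++u) // u = signed int2 weight + 2
        {
            float extracted = 1024.0f + float(u << (2 * k)); // value produced by the lop3 step
            float dequant = extracted * scale[k] - offset[k];
            assert(dequant == float(u) - 2.0f); // recovers the signed weight exactly
        }
    }
}
```

For field 0 the kernel issues a plain sub.f16x2; the multiply-by-one in the sketch only keeps the loop uniform.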
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_FOUR), "r"(NEG_258)); + // Convert elt_45 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(ONE_SIXTEENTH), "r"(NEG_66)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTY_FOUR), "r"(NEG_18)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(ONE_FOUR), "r"(NEG_258)); + // Convert elt_45 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(ONE_SIXTEENTH), "r"(NEG_66)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(ONE_SIXTY_FOUR), "r"(NEG_18)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 16; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t source_i2s = reinterpret_cast(source); + + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x00030003; + static constexpr uint32_t I2s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i2s = source_i2s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i2s), "n"(MASK), "n"(I2s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) + { + i2s >>= sizeof_bits::value; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i2s), "n"(MASK), "n"(I2s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-130, -130} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC302C302; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. 
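bf16 has only a 7-bit mantissa, so the per-position scale trick used in the fp16 path does not fit; instead the source register is shifted by two bits per iteration and the same mask 0x00030003 with the magic number 0x4300 (bf16 128.0, mantissa ulp of 1) is reused, producing 128 + u per lane. The single fma.rn.bf16x2 below, with BF16_ONE = 1.0 and BF16_BIAS = -130, then yields u - 2. A minimal host-side check follows; `check_bf16_int2_constants` is an illustrative name, not part of the patch.

```cpp
// Host-side check of the bf16 correction below (illustrative only).
#include <cassert>

void check_bf16_int2_constants()
{
    for (int u = 0; u <= 3; ++u) // u = signed int2 weight + 2
    {
        float extracted = 128.0f + float(u);       // bit pattern 0x4300 | u
        float dequant = extracted * 1.0f - 130.0f; // fma with BF16_ONE and BF16_BIAS
        assert(dequant == float(u) - 2.0f);        // recovers the signed weight
    }
}
```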
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) + { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + //*/ +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 16; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + //printf("convert uint2 to bfloat16\n"); + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp index 84cb50917..f1410fbd3 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp @@ -114,6 +114,9 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) case QuantType::W4_AFP8: details = getLayoutDetailsForArchAndQuantType(); break; + case QuantType::W2_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; default: TLLM_THROW("Unsupported quantization type"); } return details; @@ -173,6 +176,12 @@ std::vector get_permutation_map(QuantType quant_type) return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; } + else if (quant_type == QuantType::W2_A16) + { + return {0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57, 2, 3, 10, 11, 18, 19, 26, 27, 34, 35, + 42, 43, 50, 51, 58, 59, 4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61, 6, 7, 14, 15, 22, + 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63}; + } else { TLLM_THROW("Invalid quantization type for LDSM permutation"); @@ -350,6 +359,32 @@ void subbyte_transpose_impl( } } } + else if constexpr (bits_per_elt == 2) + { + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). 
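The swap loop that follows operates on packed storage: with 2-bit elements there are four elements per byte (ELTS_PER_BYTE = 4), so element i of a row lives in byte i / 4 at bit offset 2 * (i % 4). The sketch below expresses the same indexing with a pair of illustrative helpers, `get_int2` and `set_int2`, which are not part of the patch.

```cpp
// Illustrative helpers mirroring the 2-bit indexing used by the swap below
// (reference only; the patch manipulates cache_buf directly).
#include <cstdint>

inline uint8_t get_int2(uint8_t const* packed_row, int idx)
{
    return 0x3 & (packed_row[idx / 4] >> (2 * (idx % 4)));
}

inline void set_int2(uint8_t* packed_row, int idx, uint8_t val)
{
    packed_row[idx / 4] &= ~(0x3 << (2 * (idx % 4)));      // clear the 2-bit slot
    packed_row[idx / 4] |= (val & 0x3) << (2 * (idx % 4)); // write the new value
}
```

With these, the loop body is simply an in-place swap of element (ii, jj) with element (jj, ii) inside the square L1 cache tile.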
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) + { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; + + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0x3 & (cache_buf[ii][jj_byte] >> (2 * jj_bit_offset)); + uint8_t tgt_elt = 0x3 & (cache_buf[jj][ii_byte] >> (2 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (~(0x3 << (2 * jj_bit_offset))); + cache_buf[jj][ii_byte] &= (~(0x3 << (2 * ii_bit_offset))); + + cache_buf[ii][jj_byte] |= (tgt_elt << (2 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (2 * ii_bit_offset)); + } + } + } else { TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type."); @@ -400,6 +435,10 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quanti { subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); } + else if (quant_type == QuantType::W2_A16) + { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else { TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); @@ -485,6 +524,70 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz } } +void add_bias_and_interleave_int2s_inplace(int8_t* packed_int2_tensor, const size_t num_elts) +{ + const int num_bytes = num_elts / 4; + + // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little + // instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) + { + int8_t transformed_packed_int2s = 0; + int8_t transformed_elt0 + = (int8_t(packed_int2_tensor[ii] << 6) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt1 + = (int8_t(packed_int2_tensor[ii] << 4) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt2 + = (int8_t(packed_int2_tensor[ii] << 2) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt3 = (packed_int2_tensor[ii] >> 6) + 2; + + TLLM_CHECK_WITH_INFO( + transformed_elt0 >= 0 && transformed_elt0 <= 3, "Illegal result for int2 transform (elt0)"); + TLLM_CHECK_WITH_INFO( + transformed_elt1 >= 0 && transformed_elt1 <= 3, "Illegal result for int2 transform (elt1)"); + TLLM_CHECK_WITH_INFO( + transformed_elt2 >= 0 && transformed_elt2 <= 3, "Illegal result for int2 transform (elt2)"); + TLLM_CHECK_WITH_INFO( + transformed_elt3 >= 0 && transformed_elt3 <= 3, "Illegal result for int2 transform (elt3)"); + + // We don't need to mask in these ops since everything should be in the range 0-3 + transformed_packed_int2s |= transformed_elt0; + transformed_packed_int2s |= (transformed_elt1 << 2); + transformed_packed_int2s |= (transformed_elt2 << 4); + transformed_packed_int2s |= (transformed_elt3 << 6); + packed_int2_tensor[ii] = transformed_packed_int2s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt15 ... elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_15 ... elt_5 elt_3 elt_1 elt_14 ... 
elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + TLLM_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int2 tensor must be a multiple of 16 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int2_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) + { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 16; ++dest_idx) + { + const int src_idx = dest_idx < 8 ? 2 * dest_idx : 2 * (dest_idx - 8) + 1; + const int src_shift = 2 * src_idx; + const int dest_shift = 2 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0x3; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) { if (quant_type == QuantType::W8_A16) @@ -499,6 +602,10 @@ void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size // for conversion to FP16. add_bias_and_interleave_int4s_inplace(tensor, num_elts); } + else if (quant_type == QuantType::W2_A16) + { + add_bias_and_interleave_int2s_inplace(tensor, num_elts); + } else { TLLM_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h index b12fd7372..5b3e62332 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h @@ -33,7 +33,8 @@ enum class QuantType { W8_A16, W4_A16, - W4_AFP8 + W4_AFP8, + W2_A16 }; constexpr int get_weight_quant_bits(QuantType quant_type) @@ -43,6 +44,7 @@ constexpr int get_weight_quant_bits(QuantType quant_type) case QuantType::W8_A16: return 8; case QuantType::W4_A16: return 4; case QuantType::W4_AFP8: return 4; + case QuantType::W2_A16: return 2; default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1; } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h index 0ec8ab2e3..501ff6526 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h @@ -18,7 +18,9 @@ #include #include +#if defined(ENABLE_FP8) #include +#endif #include "cutlass/bfloat16.h" #include "cutlass/float8.h" diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu new file mode 100644 index 000000000..a5b523ea4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu new file mode 100644 index 000000000..60d2e0903 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu new file mode 100644 index 000000000..1d00561be --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu new file mode 100644 index 000000000..0966ed761 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu new file mode 100644 index 000000000..20f17fadb --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu new file mode 100644 index 000000000..a000397f8 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 0d32045eb..aba7a7468 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -78,7 +78,8 @@ void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, ""); // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. @@ -263,6 +264,7 @@ void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, Sca + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); } +#ifdef ENABLE_FP8 else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { @@ -271,6 +273,7 @@ void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, Sca + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); } +#endif else { generic_mixed_gemm_kernelLauncher constexpr bool is_fp8() { +#ifdef ENABLE_FP8 return std::is_same_v || std::is_same_v; +#else + return false; +#endif } template
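For reference, the host-side preprocessing that the new converters rely on can be summarized for a single 32-bit register as follows. Step 1 of add_bias_and_interleave_int2s_inplace rebiases each signed 2-bit weight from [-2, 1] into the unsigned range [0, 3]; step 2 rearranges the sixteen 2-bit fields so that even-indexed elements occupy the low half-word and odd-indexed elements the high half-word, which is the layout the FastInterleavedAndBiasedNumericArrayConverter specializations assume. The sketch below mirrors those two steps; `preprocess_int2_register` is an illustrative name and the real code operates in place on the full packed tensor.

```cpp
// Standalone sketch of the W2_A16 weight preprocessing for one 32-bit register
// (mirrors add_bias_and_interleave_int2s_inplace; illustrative only).
#include <cstdint>

uint32_t preprocess_int2_register(uint32_t packed_signed_int2s)
{
    // Step 1: rebias each signed 2-bit value from [-2, 1] to unsigned [0, 3].
    uint32_t biased = 0;
    for (int i = 0; i < 16; ++i)
    {
        int32_t s = int32_t((packed_signed_int2s >> (2 * i)) & 0x3);
        s = (s ^ 0x2) - 0x2; // sign-extend the 2-bit field
        biased |= uint32_t((s + 2) & 0x3) << (2 * i);
    }

    // Step 2: interleave so even-indexed elements land in the low half-word and
    // odd-indexed elements in the high half-word, matching the converter layout.
    uint32_t interleaved = 0;
    for (int dest = 0; dest < 16; ++dest)
    {
        int const src = dest < 8 ? 2 * dest : 2 * (dest - 8) + 1;
        interleaved |= ((biased >> (2 * src)) & 0x3) << (2 * dest);
    }
    return interleaved;
}
```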