diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 1588514f608a..c8fdc1c70f5c 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -144,23 +144,9 @@ compared to 1*64 when the hasLeadingOffset is false. int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); - // Note: the following settings is customized to avoid - // **load** bank conflicts - // - // vecSize is set to k_base, which is the number of elements each - // workitem loads for one mfma instruction. - // For now, the k_base rules are as follows - // 1. All selected mfma instructions produce a single block - // 2. For f16 data type, 2 VGPRs are used for operand A --> k_base = 4 - // 3. For non-f16 data types, 1 VGPR are used for operand A - // k_base = 32 / elemTypeInBits - // 4. TODO: what about f64? - // + // vecSize is set to kWidth of the dotop layout + int vecSize = dotOpEnc.getKWidth(); // maxPhase is set to SIMDWidth / perPhase - int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit; - // On MI300, fp8 and int8 mfma is assigned to VGPRs for the operands - if (3 == mfmaEnc.getVersionMajor() && 8 == typeWidthInBit) - vecSize = 8; int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); // TODO (zhanglx): figure out better parameters for mfma4 auto mDim = mfmaEnc.getMDim(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 3cf5bf209556..64661faa3475 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -17,10 +17,9 @@ std::unique_ptr createTritonAMDGPUDotSlicingPass(int sliceKTile = 0); std::unique_ptr createTritonGPUAccelerateMatmulPass(int computeCapability = 80); -std::unique_ptr -createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), - int matrixInstructionSize = 0, - bool enableWmmaTransform = false); +std::unique_ptr createTritonAMDGPUAccelerateMatmulPass( + std::string archGenName = std::string(), int matrixInstructionSize = 0, + int kpack = 1, bool enableWmmaTransform = false); std::unique_ptr createTritonGPUPrefetchPass(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index c5a47830ced4..fcfa089cbe68 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -125,6 +125,9 @@ def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir: Option<"matrixInstructionSize", "matrix-instruction-size", "int32_t", /*default*/"0", "enforce matrix instruction MN size">, + Option<"kpack", "kpack", + "int32_t", /*default*/"1", + "Kwidth / k_base">, Option<"enableWmmaTransform", "enable-wmma-transform", "bool", /*default*/"false", "temporary option, required for lit tests only"> diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 528e3ca4dc28..9cb4de97d744 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -223,6 +223,7 @@ class MfmaInsn { unsigned getMDim(); unsigned getNDim(); StringRef getInsnName(); + unsigned getKBase(); }; } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp index 179298813fda..10bec3614969 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -185,8 +185,9 @@ struct DotOpMFMAConversionHelper { MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB, mfmaVersion); if (failed(maybeMfmaInsn)) llvm::report_fatal_error("No match found in MFMA database\n"); - else - mfmaInsnName = (*maybeMfmaInsn).getInsnName(); + + mfmaInsnName = (*maybeMfmaInsn).getInsnName(); + unsigned k_base = (*maybeMfmaInsn).getKBase(); auto aEncoding = aTensorTy.getEncoding().cast(); auto bEncoding = bTensorTy.getEncoding().cast(); @@ -207,10 +208,11 @@ struct DotOpMFMAConversionHelper { auto numRepN = repB[1]; auto numRepK = repA[1]; - ValueTable ha = getValuesFromDotOperandLayoutStruct( - loadedA, numRepM, numRepK, kWidth, aTensorTy.getElementType()); - ValueTable hb = getValuesFromDotOperandLayoutStruct( - loadedB, numRepN, numRepK, kWidth, aTensorTy.getElementType()); + auto operandA = getValuesFromDotOperandLayoutStruct( + loadedA, numRepM, numRepK, kWidth, k_base, aTensorTy.getElementType()); + auto operandB = getValuesFromDotOperandLayoutStruct( + loadedB, numRepN, numRepK, kWidth, k_base, aTensorTy.getElementType()); + auto dstElemTy = dTensorTy.getElementType(); auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dstElemTy); @@ -231,12 +233,15 @@ struct DotOpMFMAConversionHelper { vecTy, acc, fc[m * numRepN * elemsPerVec + n * elemsPerVec + v], i32_val(v)); } + acc = zeroAuxiliarBlocks(subBlocks, acc); - for (size_t k = 0; k < numRepK; k++) { - acc = mfmaLayout.getIsTransposed() - ? generateMFMAOp(mfmaInsnName, hb[{n, k}], ha[{m, k}], acc) - : generateMFMAOp(mfmaInsnName, ha[{m, k}], hb[{n, k}], acc); - } + for (size_t k = 0; k < numRepK; k++) + for (int kpack = 0; kpack < kWidth / k_base; ++kpack) + acc = mfmaLayout.getIsTransposed() + ? generateMFMAOp(mfmaInsnName, operandB[kpack][{n, k}], + operandA[kpack][{m, k}], acc) + : generateMFMAOp(mfmaInsnName, operandA[kpack][{m, k}], + operandB[kpack][{n, k}], acc); acc = reduceSubBlocks(subBlocks, acc); for (unsigned v = 0; v < elemsPerVec; ++v) { fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] = @@ -254,32 +259,72 @@ struct DotOpMFMAConversionHelper { return success(); } -/** - * @brief Converts dot operand structure to value table and converts types appropriate for mfma instructions -*/ - ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth, - Type type) const { + /** + * @brief extract vector from rawElems based on kWidth and k_base + * rawElems is a vector of kWidth elements. We need to prepare vector(s) of + * k_base elements for each mfma instruction + */ + SmallVector extractOperands(Value rawElems, int kWidth, int k_base, + Type type) const { + int kpack = kWidth / k_base; + SmallVector results; + auto vecTy = vec_ty(type, k_base); + for (int k = 0; k < kpack; ++k) { + Value vec = undef(vecTy); + for (int elemId = 0; elemId < k_base; ++elemId) { + auto val = + extract_element(type, rawElems, i32_val(elemId + k * k_base)); + vec = insert_element(vecTy, vec, val, i32_val(elemId)); + } + if (type.getIntOrFloatBitWidth() == 8) { + if (4 == k_base) + // This is for int8 on pre- MI300 GPUs + results.push_back(bitcast(vec, i32_ty)); + if (8 == k_base) + results.push_back(bitcast(vec, i64_ty)); + } else + results.push_back(vec); + } + return results; + } + + /** + * @brief Converts dot operand structure to value table and converts types + * appropriate for mfma instructions + */ + SmallVector + getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth, + int k_base, Type type) const { auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type); ValueTable vals; + ValueTable vals1; + int kpack = kWidth / k_base; + SmallVector dotOpVals(kpack); for (int i = 0; i < n0; i++) { for (int j = 0; j < n1; j++) { auto rawElems = elems[n1 * i + j]; - Value convertedElems; + if (type.isF32()) { - convertedElems = extract_element(type, rawElems, i32_val(0)); - } else if (type.getIntOrFloatBitWidth() == 8) { - if (kWidth == 4) - convertedElems = bitcast(rawElems, i32_ty); - if (kWidth == 8) - convertedElems = bitcast(rawElems, i64_ty); + for (int k = 0; k < kpack; ++k) { + dotOpVals[k][{i, j}] = extract_element(type, rawElems, i32_val(k)); + } } else { - assert(type.isBF16() || type.isF16()); - convertedElems = rawElems; + SmallVector vals; + if (type.getIntOrFloatBitWidth() == 8) { + vals = extractOperands(rawElems, kWidth, k_base, i8_ty); + } else if (type.isBF16()) { + vals = extractOperands(rawElems, kWidth, k_base, i16_ty); + } else { + assert(type.isF16() && "Unsupported data type"); + vals = extractOperands(rawElems, kWidth, k_base, f16_ty); + } + for (int k = 0; k < kpack; ++k) { + dotOpVals[k][{i, j}] = vals[k]; + } } - vals[{i, j}] = convertedElems; } } - return vals; + return dotOpVals; } }; diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 35c45fd2c069..0d4ebc4c1a0b 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -131,12 +131,24 @@ warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { class BlockedToMFMA : public mlir::RewritePattern { int mfmaVersion; + int kpack; int enforcedNonKDim; public: - BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim) + BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim, + int kpack) : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), - mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim) {} + mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kpack(kpack) {} + + bool isSecondDot(tt::DotOp &dotOp) const { + SetVector slices; + mlir::getBackwardSlice(dotOp.getResult(), &slices); + if (llvm::find_if(slices, [](Operation *op) { + return isa(op); + }) != slices.end()) + return true; + return false; + } /// @brief Choose MFMA instruction parameters /// @param dot target dot operation @@ -258,8 +270,8 @@ class BlockedToMFMA : public mlir::RewritePattern { .cast() .getOrder(); - // kWidth is a number of consecutive elements per one instruction per one - // thread + // kWidth is initialized as k_base, which is the number of elements hold by + // one thread per mfma instruction auto kWidth = -1; // in mfma 32x32 case argument matrix groups elements in 2 groups // in mfma 16x16 case argument matrix groups elements in 4 groups @@ -273,6 +285,14 @@ class BlockedToMFMA : public mlir::RewritePattern { if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) kWidth = kDim; assert(kWidth != -1); + + // We want to extend kWidth by kpack (kpack=1 means no extension) + // to increase ds_read vector size + // However, in FA, the second dot can only use kWidth = k_bse since it's + // limited by the result of the first dot, which is of mfmaLayout. + if (!isSecondDot(dotOp)) + kWidth *= kpack; + auto newAType = RankedTensorType::get( oldAType.getShape(), oldAType.getElementType(), ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); @@ -367,11 +387,11 @@ class TritonAMDGPUAccelerateMatmulPass TritonAMDGPUAccelerateMatmulPass> { public: TritonAMDGPUAccelerateMatmulPass() = default; - TritonAMDGPUAccelerateMatmulPass(StringRef archGen, - int matrixInstructionSize, - bool enableWmmaTransform) { + TritonAMDGPUAccelerateMatmulPass(StringRef archGen, int matrixInstructionSize, + int kpack, bool enableWmmaTransform) { this->archGenerationName = archGen.data(); this->matrixInstructionSize = matrixInstructionSize; + this->kpack = kpack; this->enableWmmaTransform = enableWmmaTransform; } void runOnOperation() override { @@ -384,7 +404,7 @@ class TritonAMDGPUAccelerateMatmulPass MatrixCoreVersion::CDNA_MFMA2 == matrixCoreVer || MatrixCoreVersion::CDNA_MFMA3 == matrixCoreVer) { patterns.add<::BlockedToMFMA>(context, getMfmaVersion(matrixCoreVer), - matrixInstructionSize); + matrixInstructionSize, kpack); } else if (MatrixCoreVersion::RDNA_WMMA == matrixCoreVer && enableWmmaTransform) { patterns.add<::BlockedToWMMA>(context); @@ -395,10 +415,9 @@ class TritonAMDGPUAccelerateMatmulPass } }; -std::unique_ptr -mlir::createTritonAMDGPUAccelerateMatmulPass(std::string archGen, - int matrixInstructionSize, - bool enableWmmaTransform) { +std::unique_ptr mlir::createTritonAMDGPUAccelerateMatmulPass( + std::string archGen, int matrixInstructionSize, int kpack, + bool enableWmmaTransform) { return std::make_unique( - archGen, matrixInstructionSize, enableWmmaTransform); + archGen, matrixInstructionSize, kpack, enableWmmaTransform); } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index f869dff81ca1..23f1befd2617 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -859,5 +859,6 @@ unsigned MfmaInsn::getKDim() { return attr.k; } unsigned MfmaInsn::getMDim() { return attr.m; } unsigned MfmaInsn::getNDim() { return attr.n; } StringRef MfmaInsn::getInsnName() { return attr.insn; } +unsigned MfmaInsn::getKBase() { return attr.k_base;} } // namespace mlir diff --git a/python/src/triton.cc b/python/src/triton.cc index 850329ed9a59..7ec01f15841f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1875,9 +1875,10 @@ void init_triton_ir(py::module &&m) { mlir::createTritonGPUAccelerateMatmulPass(computeCapability)); }) .def("add_tritonamdgpu_accelerate_matmul_pass", - [](mlir::PassManager &self, const std::string archGenName, int instrSize) { + [](mlir::PassManager &self, const std::string archGenName, + int instrSize, int kpack) { self.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass( - archGenName, instrSize)); + archGenName, instrSize, kpack)); }) .def("add_tritongpu_optimize_dot_operands_pass", [](mlir::PassManager &self) { diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 0816c7fe2978..235a263484e5 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -25,6 +25,10 @@ num_ctas_list = [1] +def get_matrix_core_version(): + backend = triton.common.backend.get_backend("hip") + return backend.get_matrix_core_version() + def hip_skip(): import inspect return pytest.skip(f"Skipping {inspect.stack()[1][3]}!") @@ -1040,11 +1044,10 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): check_type_supported(in_dtype, device) check_type_supported(out_dtype, device) - backend = triton.common.backend.get_backend("hip") - if backend.get_matrix_core_version() == 3: + if get_matrix_core_version() == 3: if out_dtype == torch.bfloat16 and (in_dtype == tl.float8e4b15 or in_dtype == tl.float8e4b15x4): pytest.skip(f"Type conversion between {in_dtype} and {out_dtype} is not available on hardware") - elif backend.get_matrix_core_version() == 2: + elif get_matrix_core_version() == 2: if out_dtype == torch.bfloat16 and in_dtype != tl.float8e5: pytest.skip(f"Type conversion between {in_dtype} and {out_dtype} is not available on hardware") @@ -1263,8 +1266,7 @@ def matmul(a, b, c_type): def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'): check_type_supported(out_dtype, device) - backend = triton.common.backend.get_backend("hip") - if backend.get_matrix_core_version() != 3: + if get_matrix_core_version() != 3: pytest.skip("fp8 data type is not available on hardware") @triton.jit @@ -1619,9 +1621,9 @@ def kernel(X, stride_xm, stride_xn, # --------------- -@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim", +@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, kpack", # FMA Test Dot tests - [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype, 0) + [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype, 0, 1) for shape in [(64, 64, 64), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] for allow_tf32 in [True, False] @@ -1630,7 +1632,7 @@ def kernel(X, stride_xm, stride_xn, ('float32', 'float32')] if not (allow_tf32 and (in_dtype in ['float16']))] + - [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, 0) + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, 0, 1) for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], @@ -1648,9 +1650,9 @@ def kernel(X, stride_xm, stride_xn, ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] - if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0 else + if get_matrix_core_version() == 0 else # MFMA Test Dot tests - [(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim) + [(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, 1) for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax'] for allow_tf32 in [True, False] @@ -1663,7 +1665,7 @@ def kernel(X, stride_xm, stride_xn, for non_k_dim in [0, 4, 16, 32] if not (allow_tf32 and (in_dtype in ['float16']))] + - [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim) + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim, kpack) for shape_nw in [[128, 128, 32, 2], [128, 16, 32, 4], [128, 128, 64, 2], @@ -1699,15 +1701,15 @@ def kernel(X, stride_xm, stride_xn, for col_b in [True, False] for in_dtype in ['int8', 'bfloat16', 'float16', 'float32'] for out_dtype in [None] - for non_k_dim in [0, 4, 16, 32]]) -def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, device='cuda'): + for non_k_dim in [0, 4, 16, 32] + for kpack in [1, 2]]) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, kpack, device='cuda'): capability = torch.cuda.get_device_capability() if torch.version.hip is not None: - # TODO consider reenabling this tests when fp8 casing is fixed - if M == 16 and N == 16 and K == 16 and "float8" in in_dtype: - pytest.skip("triton do not generate MFMA instructions for given block size") + if (M < 16 or N < 16) and kpack == 2: + pytest.skip("Skip tests for mfma4 with kpack=2") # set capability to large number to jump over check below # check are not relevant to amd gpu, left them for smaller diff between test_core.py and test_core_amd.py tests if (M, N, K) == (128, 256, 32): @@ -1877,7 +1879,8 @@ def kernel(X, stride_xm, stride_xk, CHAIN_DOT=epilogue == 'chain-dot', ALLOW_TF32=allow_tf32, num_warps=num_warps, - matrix_instr_nonkdim=non_k_dim) + matrix_instr_nonkdim=non_k_dim, + kpack = kpack) # torch result if in_dtype == 'int8': z_ref = np.matmul(x.astype(np.float32), @@ -1912,8 +1915,7 @@ def kernel(X, stride_xm, stride_xk, # added atol, to loose precision for float16xfloat16->float32 case np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) if torch.version.hip is not None: - backend = triton.common.backend.get_backend("hip") - if backend.get_matrix_core_version() > 0: + if get_matrix_core_version() > 0: ttgir = pgm.asm['ttgir'] if non_k_dim == 0 and ((M == 4 and N == 64) or (M == 64 and N == 4)): m64n4 = "instrShape = [64, 4]" @@ -1935,10 +1937,12 @@ def kernel(X, stride_xm, stride_xk, assert "instrShape = [16, 16]" not in ttgir assert "instrShape = [32, 32]" in ttgir gcn = pgm.asm['amdgcn'] - if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16: + if get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16: assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn - if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8: + if get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8: assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn + if get_matrix_core_version() == 3 and kpack == 2: + assert "ds_read_b128" in gcn and "ds_write_b128" in gcn return # make sure ld/st are vectorized ptx = pgm.asm['ptx'] @@ -2771,7 +2775,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): if transposeA and not transposeB: pytest.skip() - if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0: + if get_matrix_core_version() == 0: pytest.skip("mfma is not available on hardware") # source code for following ttgir: @@ -2862,7 +2866,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): import triton.language.semantic as sem # if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0: - if torch.version.hip is not None and backend.get_matrix_core_version() > 0: + if torch.version.hip is not None and get_matrix_core_version() > 0: kernel[(1, 1, 1)](x_tri, y_tri, z_tri) np.testing.assert_allclose(z_np, to_numpy(z_tri), rtol=0.01, atol=1e-3) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index efc336da3362..2899168ea783 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -100,7 +100,7 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target): def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, - enable_persistent, optimize_epilogue, matrix_inst_type, slice_k_tile): + enable_persistent, optimize_epilogue, matrix_inst_type, slice_k_tile, kpack): is_cuda = _is_cuda(target) if is_cuda: capability = target.capability @@ -119,7 +119,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e if is_hip(): gfx_arch = target["gfx_arch"] matrix_inst_size = matrix_inst_type - pm.add_tritonamdgpu_accelerate_matmul_pass(gfx_arch, matrix_inst_size) + pm.add_tritonamdgpu_accelerate_matmul_pass(gfx_arch, matrix_inst_size, kpack) pm.add_tritongpu_remove_layout_conversions_pass() if optimize_epilogue: pm.add_tritongpu_optimize_epilogue_pass() @@ -275,7 +275,8 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs): num_stages = kwargs.get("num_stages", 3) waves_per_eu = kwargs.get("waves_per_eu", 0) slice_k_tile = kwargs.get("slice_k_tile", 0) - matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0); + matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0) + kpack = kwargs.get("kpack", 1) enable_warp_specialization = kwargs.get("enable_warp_specialization", False) enable_persistent = kwargs.get("enable_persistent", False) debug = kwargs.get("debug", False) @@ -284,7 +285,7 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs): sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{slice_k_tile}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" + key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{slice_k_tile}-{matrix_instr_nonkdim}-{kpack}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) ignore_version = kwargs.get('ignore_version', False) @@ -418,6 +419,7 @@ def compile(fn, **kwargs): waves_per_eu = kwargs.get("waves_per_eu", 0) slice_k_tile = kwargs.get("slice_k_tile", 0) matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0) + kpack = kwargs.get("kpack", 1) enable_fp_fusion = kwargs.get("enable_fp_fusion", True) # TODO[shuhaoj]: Default should be to enable warp specialization once possible enable_warp_specialization = kwargs.get("enable_warp_specialization", False) @@ -477,6 +479,7 @@ def compile(fn, **kwargs): other["waves_per_eu"] = waves_per_eu other["slice_k_tile"] = slice_k_tile other["matrix_instr_nonkdim"] = matrix_instr_nonkdim + other["kpack"] = kpack _device_backend.add_stages(target, extern_libs, stages, other) elif device_type == "xpu": @@ -562,6 +565,7 @@ def compile(fn, **kwargs): "waves_per_eu": waves_per_eu, "slice_k_tile": slice_k_tile, "matrix_instr_nonkdim": matrix_instr_nonkdim, + "kpack": kpack, "enable_warp_specialization": enable_warp_specialization, "enable_persistent": enable_persistent, "constants": _get_jsonable_constants(constants), diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 9db08ce6586f..4da90e1d3594 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -353,6 +353,7 @@ def _call_hook( waves_per_eu, slice_k_tile, matrix_instr_nonkdim, + kpack, enable_warp_specialization, enable_fp_fusion, extern_libs, @@ -364,7 +365,7 @@ def _call_hook( name = self.fn.__name__ module = self.fn.__module__ arg_reprs = ', '.join([f'{param.name}: {ty}' for param, ty in zip(self.params, key[1])]) - repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, slice_k_tile={slice_k_tile}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})" + repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, slice_k_tile={slice_k_tile}, matrix_instr_nonkdim={matrix_instr_nonkdim}, kpack={kpack}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})" key = str(key) class LegacyCompiler: @@ -431,6 +432,7 @@ def get_special_arg(name: str, default=None): waves_per_eu = get_special_arg("waves_per_eu", 0) slice_k_tile = get_special_arg("slice_k_tile", 0) matrix_instr_nonkdim = get_special_arg("matrix_instr_nonkdim", 0) + kpack = get_special_arg("kpack", 1) enable_warp_specialization = get_special_arg("enable_warp_specialization", False) enable_fp_fusion = get_special_arg("enable_fp_fusion", True) extern_libs = get_special_arg("extern_libs") @@ -508,6 +510,7 @@ def get_special_arg(name: str, default=None): waves_per_eu, slice_k_tile, matrix_instr_nonkdim, + kpack, enable_warp_specialization, enable_fp_fusion, self.debug, @@ -545,6 +548,7 @@ def get_special_arg(name: str, default=None): waves_per_eu, slice_k_tile, matrix_instr_nonkdim, + kpack, enable_warp_specialization, enable_fp_fusion, extern_libs, @@ -563,6 +567,7 @@ def get_special_arg(name: str, default=None): waves_per_eu=waves_per_eu, slice_k_tile=slice_k_tile, matrix_instr_nonkdim=matrix_instr_nonkdim, + kpack=kpack, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, diff --git a/python/triton/third_party/hip/hip_backend.py b/python/triton/third_party/hip/hip_backend.py index 7ef180cfba6f..73fe70f2592b 100644 --- a/python/triton/third_party/hip/hip_backend.py +++ b/python/triton/third_party/hip/hip_backend.py @@ -451,9 +451,10 @@ def add_stages(self, arch: dict, extern_libs: dict, stages: dict, other: dict = waves_per_eu = other["waves_per_eu"] slice_k_tile = other["slice_k_tile"] matrix_instr_nonkdim = other["matrix_instr_nonkdim"] + kpack = other["kpack"] stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim, slice_k_tile)) + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim, slice_k_tile, kpack)) stages["llir"] = (lambda path: Path(path).read_text(), lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu)) diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py index 3c42016b855b..51ecdd285d45 100644 --- a/scripts/amd/gemm/tune_gemm.py +++ b/scripts/amd/gemm/tune_gemm.py @@ -31,6 +31,7 @@ def get_full_tuning_space(): num_stage_range = [0] waves_per_eu_range = [0] matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] for block_m in block_mn_range: for block_n in block_mn_range: @@ -41,7 +42,8 @@ def get_full_tuning_space(): for num_stages in num_stage_range: for waves_per_eu in waves_per_eu_range: for matrix_instr_nonkdim in matrix_instr_nonkdim_range: - configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim}) + for kpack in kpack_range: + configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack}) return configs @@ -65,6 +67,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") num_warps = config.get("num_warps") matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + kpack = config.get("kpack") if matrix_instr_nonkdim > mfma: continue if mfma == 4 and BLOCK_SIZE_K < 64: @@ -104,10 +107,9 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): if LDS > 65536: continue # Skip small block sizes and num_warps for large gemm - # For fp8, we want to only use BLOCK_SIZE >= 128 - # For fp16, we want to only use BLOCK_SIZE >= 64 + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 if large_gemm: - if BLOCK_SIZE_M < (128/elemBytes_a) or BLOCK_SIZE_N < (128/elemBytes_a): + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: continue if BLOCK_SIZE_K < 64: continue @@ -147,11 +149,12 @@ def read_config(config): num_stages = config.get('num_stages') waves_per_eu = config.get('waves_per_eu') mfma_instr_size = config.get('matrix_instr_nonkdim') - return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size + kpack = config.get('kpack') + return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack def gen_kernel_and_configStr_from_config(M, N, K, config, dtype_a, dtype_b, dtype_c): - block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize = read_config(config) + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) torch_dtype_a = 'fp16' torch_dtype_b = 'fp16' torch_dtype_c = 'fp16' @@ -161,7 +164,7 @@ def gen_kernel_and_configStr_from_config(M, N, K, config, dtype_a, dtype_b, dtyp torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] if dtype_c: torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] - configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_mfma{mfmaInstrSize}" + configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" matmul_def_str = f""" def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False): @@ -181,6 +184,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False): num_stages = {num_stages}, waves_per_eu = {waves_per_eu}, matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, grid=(1,) ) return None @@ -198,6 +202,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False): num_stages = {num_stages}, waves_per_eu = {waves_per_eu}, matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack} ) return c @@ -469,7 +474,7 @@ def init_by_size_and_type(size, dtype, init_type): return input, input_f16 -def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize): +def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" #assert a.is_contiguous(), "Matrix A must be contiguous" @@ -494,13 +499,14 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_ num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - matrix_instr_nonkdim = mfmaInstrSize + matrix_instr_nonkdim = mfmaInstrSize, + kpack = kpack ) return c def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose): - block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize = read_config(config) + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) torch.manual_seed(0) #a = torch.randn((M, K), device='cuda', dtype=datatype) #b = torch.randn((K, N), device='cuda', dtype=datatype) @@ -508,7 +514,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda') # Allocates output. c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) - triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize) + triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack) torch_output = torch.matmul(a_fp16, b_fp16) # print(f"triton_output={triton_output}") # print(f"torch_output={torch_output}") diff --git a/scripts/amd/occ.sh b/scripts/amd/occ.sh index 5dd3d8861b58..d55fe00250c7 100755 --- a/scripts/amd/occ.sh +++ b/scripts/amd/occ.sh @@ -5,14 +5,14 @@ rm -rf ~/.triton/cache/ export MLIR_ENABLE_DUMP=1 +export LLVM_IR_ENABLE_DUMP=1 export AMDGCN_ENABLE_DUMP=1 ## Assume CDNA arch SIMD=4 LDS_SIZE=65536 TOTAL_VGPR=512 -python -u $1 > output.mlir 2>&1 - +$1 > output.mlir 2>&1 LDS_line=$(sed -n '/triton_gpu\.shared\ /p' output.mlir | tail -n 1 | grep -o 'triton_gpu.shared = [0-9]*') numWarps_line=$(sed -n '/triton_gpu\.num-warps/p' output.mlir | tail -n 1 | grep -o 'triton_gpu.num-warps. = [0-9]*') @@ -34,5 +34,10 @@ if [ $occ_LDS -lt $occ_vgpr ];then fi echo "occ: $occ waves/SIMD (occ_LDS: $occ_LDS, occ_vgpr: $occ_vgpr)" -perf=$(tail -n 1 output.mlir | awk '{print $NF}') -printf "perf: %.1f tflops\n" $perf +perf=$(tail -n 2 output.mlir) +echo "$perf" + +## remove distracting info from the assembly +sed -i '/\.loc/d' output.mlir +sed -i '/\.Ltmp.*:/d' output.mlir +sed -i '/AMD clang version/d' output.mlir