Add option for larger LDS vecSize (#476)
* Add kpack to kernel arg and refactor getValuesFromDotOperandLayout

* Add kpack in the tuning space

* Add kpack in test_dot
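
For illustration only (the kernel-argument, tuning-space, and test_dot changes live in files not shown in this excerpt), here is a sketch of what the new knob adds to a GEMM tuning sweep; every name below except kpack is hypothetical:

```python
# Illustrative sketch: kpack (= kWidth / k_base) becomes one more axis of
# the matmul tuning space; kpack = 1 reproduces the old behavior. All
# config keys besides kpack are hypothetical placeholders.
import itertools

tuning_space = [
    {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "kpack": kp}
    for bm, bn, bk, kp in itertools.product(
        [64, 128], [64, 128], [32, 64], [1, 2])
]
print(len(tuning_space))  # 16 candidate configs
```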
zhanglx13 committed Feb 14, 2024
1 parent d6f14d3 commit 35edd6a
Showing 14 changed files with 186 additions and 106 deletions.
18 changes: 2 additions & 16 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -144,23 +144,9 @@ compared to 1*64 when the hasLeadingOffset is false.
     int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
 
     int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
-    // Note: the following settings is customized to avoid
-    // **load** bank conflicts
-    //
-    // vecSize is set to k_base, which is the number of elements each
-    // workitem loads for one mfma instruction.
-    // For now, the k_base rules are as follows
-    // 1. All selected mfma instructions produce a single block
-    // 2. For f16 data type, 2 VGPRs are used for operand A --> k_base = 4
-    // 3. For non-f16 data types, 1 VGPR are used for operand A
-    //    k_base = 32 / elemTypeInBits
-    // 4. TODO: what about f64?
-    //
+    // vecSize is set to kWidth of the dotop layout
+    int vecSize = dotOpEnc.getKWidth();
     // maxPhase is set to SIMDWidth / perPhase
-    int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
-    // On MI300, fp8 and int8 mfma is assigned to VGPRs for the operands
-    if (3 == mfmaEnc.getVersionMajor() && 8 == typeWidthInBit)
-      vecSize = 8;
     int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
     // TODO (zhanglx): figure out better parameters for mfma4
     auto mDim = mfmaEnc.getMDim();
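
To make the effect concrete, here is a standalone sketch of the swizzle arithmetic above for f16 (typeWidthInBit = 16) with kWidth = k_base * kpack = 4 * 2 = 8; the bank and SIMD constants are assumptions, not shown in this hunk:

```python
# Sketch of the swizzle-parameter arithmetic for f16 with kpack = 2.
# numBanks = 32, bankBitWidth = 32, SIMDWidth = 64 are assumed values.
num_banks, bank_bit_width, simd_width = 32, 32, 64
type_width_in_bit, inner_dim_length = 16, 64
k_width = 8  # k_base (4 for f16 mfma) * kpack (2)

elems_per_banks_row = (num_banks * bank_bit_width) // type_width_in_bit  # 64
per_phase = max(1, elems_per_banks_row // inner_dim_length)              # 1
vec_size = k_width  # 8 f16 elements = 128 bits -> ds_read_b128
max_phase = min(simd_width // per_phase, inner_dim_length // vec_size)   # 8
print(per_phase, vec_size, max_phase)  # 1 8 8
```

Under the old rule, vecSize would be ((16 == 16) ? 64 : 32) / 16 = 4, i.e. 64-bit shared-memory reads; tying vecSize to the kpack-extended kWidth is what the "larger LDS vecSize" in the commit title refers to.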
7 changes: 3 additions & 4 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -17,10 +17,9 @@ std::unique_ptr<Pass> createTritonAMDGPUDotSlicingPass(int sliceKTile = 0);
 std::unique_ptr<Pass>
 createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
 
-std::unique_ptr<Pass>
-createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
-                                       int matrixInstructionSize = 0,
-                                       bool enableWmmaTransform = false);
+std::unique_ptr<Pass> createTritonAMDGPUAccelerateMatmulPass(
+    std::string archGenName = std::string(), int matrixInstructionSize = 0,
+    int kpack = 1, bool enableWmmaTransform = false);
 
 std::unique_ptr<Pass> createTritonGPUPrefetchPass();
 
3 changes: 3 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -125,6 +125,9 @@ def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir:
        Option<"matrixInstructionSize", "matrix-instruction-size",
               "int32_t", /*default*/"0",
               "enforce matrix instruction MN size">,
+       Option<"kpack", "kpack",
+              "int32_t", /*default*/"1",
+              "kWidth / k_base">,
        Option<"enableWmmaTransform", "enable-wmma-transform",
               "bool", /*default*/"false",
               "temporary option, required for lit tests only">
1 change: 1 addition & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -223,6 +223,7 @@ class MfmaInsn {
   unsigned getMDim();
   unsigned getNDim();
   StringRef getInsnName();
+  unsigned getKBase();
 };
 } // namespace mlir
 
99 changes: 72 additions & 27 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -185,8 +185,9 @@ struct DotOpMFMAConversionHelper {
         MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB, mfmaVersion);
     if (failed(maybeMfmaInsn))
       llvm::report_fatal_error("No match found in MFMA database\n");
-    else
-      mfmaInsnName = (*maybeMfmaInsn).getInsnName();
+
+    mfmaInsnName = (*maybeMfmaInsn).getInsnName();
+    unsigned k_base = (*maybeMfmaInsn).getKBase();
 
     auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
     auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
@@ -207,10 +208,11 @@ struct DotOpMFMAConversionHelper {
     auto numRepN = repB[1];
     auto numRepK = repA[1];
 
-    ValueTable ha = getValuesFromDotOperandLayoutStruct(
-        loadedA, numRepM, numRepK, kWidth, aTensorTy.getElementType());
-    ValueTable hb = getValuesFromDotOperandLayoutStruct(
-        loadedB, numRepN, numRepK, kWidth, aTensorTy.getElementType());
+    auto operandA = getValuesFromDotOperandLayoutStruct(
+        loadedA, numRepM, numRepK, kWidth, k_base, aTensorTy.getElementType());
+    auto operandB = getValuesFromDotOperandLayoutStruct(
+        loadedB, numRepN, numRepK, kWidth, k_base, aTensorTy.getElementType());
+
     auto dstElemTy = dTensorTy.getElementType();
     auto fc =
         typeConverter->unpackLLElements(loc, loadedC, rewriter, dstElemTy);
@@ -231,12 +233,15 @@ struct DotOpMFMAConversionHelper {
                 vecTy, acc, fc[m * numRepN * elemsPerVec + n * elemsPerVec + v],
                 i32_val(v));
           }
+
           acc = zeroAuxiliarBlocks(subBlocks, acc);
-          for (size_t k = 0; k < numRepK; k++) {
-            acc = mfmaLayout.getIsTransposed()
-                      ? generateMFMAOp(mfmaInsnName, hb[{n, k}], ha[{m, k}], acc)
-                      : generateMFMAOp(mfmaInsnName, ha[{m, k}], hb[{n, k}], acc);
-          }
+          for (size_t k = 0; k < numRepK; k++)
+            for (int kpack = 0; kpack < kWidth / k_base; ++kpack)
+              acc = mfmaLayout.getIsTransposed()
+                        ? generateMFMAOp(mfmaInsnName, operandB[kpack][{n, k}],
+                                         operandA[kpack][{m, k}], acc)
+                        : generateMFMAOp(mfmaInsnName, operandA[kpack][{m, k}],
+                                         operandB[kpack][{n, k}], acc);
           acc = reduceSubBlocks(subBlocks, acc);
           for (unsigned v = 0; v < elemsPerVec; ++v) {
             fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] =
@@ -254,32 +259,72 @@ struct DotOpMFMAConversionHelper {
     return success();
   }
 
-  /**
-   * @brief Converts dot operand structure to value table and converts types appropriate for mfma instructions
-   */
-  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth,
-                                                 Type type) const {
+  /**
+   * @brief extract vector from rawElems based on kWidth and k_base
+   * rawElems is a vector of kWidth elements. We need to prepare vector(s) of
+   * k_base elements for each mfma instruction
+   */
+  SmallVector<Value> extractOperands(Value rawElems, int kWidth, int k_base,
+                                     Type type) const {
+    int kpack = kWidth / k_base;
+    SmallVector<Value> results;
+    auto vecTy = vec_ty(type, k_base);
+    for (int k = 0; k < kpack; ++k) {
+      Value vec = undef(vecTy);
+      for (int elemId = 0; elemId < k_base; ++elemId) {
+        auto val =
+            extract_element(type, rawElems, i32_val(elemId + k * k_base));
+        vec = insert_element(vecTy, vec, val, i32_val(elemId));
+      }
+      if (type.getIntOrFloatBitWidth() == 8) {
+        if (4 == k_base)
+          // This is for int8 on pre-MI300 GPUs
+          results.push_back(bitcast(vec, i32_ty));
+        if (8 == k_base)
+          results.push_back(bitcast(vec, i64_ty));
+      } else
+        results.push_back(vec);
+    }
+    return results;
+  }
+
+  /**
+   * @brief Converts dot operand structure to value table and converts types
+   * appropriate for mfma instructions
+   */
+  SmallVector<ValueTable>
+  getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth,
+                                      int k_base, Type type) const {
     auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type);
-    ValueTable vals;
-    ValueTable vals1;
+    int kpack = kWidth / k_base;
+    SmallVector<ValueTable> dotOpVals(kpack);
     for (int i = 0; i < n0; i++) {
      for (int j = 0; j < n1; j++) {
        auto rawElems = elems[n1 * i + j];
-        Value convertedElems;
 
        if (type.isF32()) {
-          convertedElems = extract_element(type, rawElems, i32_val(0));
-        } else if (type.getIntOrFloatBitWidth() == 8) {
-          if (kWidth == 4)
-            convertedElems = bitcast(rawElems, i32_ty);
-          if (kWidth == 8)
-            convertedElems = bitcast(rawElems, i64_ty);
+          for (int k = 0; k < kpack; ++k) {
+            dotOpVals[k][{i, j}] = extract_element(type, rawElems, i32_val(k));
+          }
        } else {
-          assert(type.isBF16() || type.isF16());
-          convertedElems = rawElems;
+          SmallVector<Value> vals;
+          if (type.getIntOrFloatBitWidth() == 8) {
+            vals = extractOperands(rawElems, kWidth, k_base, i8_ty);
+          } else if (type.isBF16()) {
+            vals = extractOperands(rawElems, kWidth, k_base, i16_ty);
+          } else {
+            assert(type.isF16() && "Unsupported data type");
+            vals = extractOperands(rawElems, kWidth, k_base, f16_ty);
+          }
+          for (int k = 0; k < kpack; ++k) {
+            dotOpVals[k][{i, j}] = vals[k];
+          }
        }
-        vals[{i, j}] = convertedElems;
      }
    }
-    return vals;
+    return dotOpVals;
  }
 };
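
A minimal plain-Python sketch of the regrouping extractOperands performs: a thread's kWidth raw elements are split into kWidth / k_base groups of k_base elements, and the kpack loop in convertDot issues one mfma per group:

```python
# Plain-Python analogue of extractOperands: split a thread's kWidth
# elements into kpack = kWidth // k_base groups, one per mfma instruction.
def extract_operands(raw_elems, k_width, k_base):
    kpack = k_width // k_base
    return [raw_elems[k * k_base:(k + 1) * k_base] for k in range(kpack)]

# f16 with kpack = 2: one 8-element LDS read feeds two mfma instructions.
print(extract_operands(list(range(8)), k_width=8, k_base=4))
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]
```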
45 changes: 32 additions & 13 deletions lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
@@ -131,12 +131,24 @@ warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
 
 class BlockedToMFMA : public mlir::RewritePattern {
   int mfmaVersion;
+  int kpack;
   int enforcedNonKDim;
 
 public:
-  BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim)
+  BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim,
+                int kpack)
       : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context),
-        mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim) {}
+        mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kpack(kpack) {}
+
+  bool isSecondDot(tt::DotOp &dotOp) const {
+    SetVector<Operation *> slices;
+    mlir::getBackwardSlice(dotOp.getResult(), &slices);
+    if (llvm::find_if(slices, [](Operation *op) {
+          return isa<tt::DotOp>(op);
+        }) != slices.end())
+      return true;
+    return false;
+  }
 
   /// @brief Choose MFMA instruction parameters
   /// @param dot target dot operation
@@ -258,8 +270,8 @@ class BlockedToMFMA : public mlir::RewritePattern {
                      .cast<ttg::BlockedEncodingAttr>()
                      .getOrder();
 
-    // kWidth is a number of consecutive elements per one instruction per one
-    // thread
+    // kWidth is initialized as k_base, which is the number of elements held
+    // by one thread per mfma instruction
     auto kWidth = -1;
     // in mfma 32x32 case argument matrix groups elements in 2 groups
     // in mfma 16x16 case argument matrix groups elements in 4 groups
@@ -273,6 +285,14 @@ class BlockedToMFMA : public mlir::RewritePattern {
     if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4)
       kWidth = kDim;
     assert(kWidth != -1);
+
+    // We want to extend kWidth by kpack (kpack=1 means no extension)
+    // to increase ds_read vector size
+    // However, in FA, the second dot can only use kWidth = k_base since it's
+    // limited by the result of the first dot, which is of mfmaLayout.
+    if (!isSecondDot(dotOp))
+      kWidth *= kpack;
+
     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
         ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
@@ -367,11 +387,11 @@ class TritonAMDGPUAccelerateMatmulPass
           TritonAMDGPUAccelerateMatmulPass> {
 public:
   TritonAMDGPUAccelerateMatmulPass() = default;
-  TritonAMDGPUAccelerateMatmulPass(StringRef archGen,
-                                   int matrixInstructionSize,
-                                   bool enableWmmaTransform) {
+  TritonAMDGPUAccelerateMatmulPass(StringRef archGen, int matrixInstructionSize,
+                                   int kpack, bool enableWmmaTransform) {
     this->archGenerationName = archGen.data();
     this->matrixInstructionSize = matrixInstructionSize;
+    this->kpack = kpack;
     this->enableWmmaTransform = enableWmmaTransform;
   }
   void runOnOperation() override {
@@ -384,7 +404,7 @@ class TritonAMDGPUAccelerateMatmulPass
         MatrixCoreVersion::CDNA_MFMA2 == matrixCoreVer ||
         MatrixCoreVersion::CDNA_MFMA3 == matrixCoreVer) {
       patterns.add<::BlockedToMFMA>(context, getMfmaVersion(matrixCoreVer),
-                                    matrixInstructionSize);
+                                    matrixInstructionSize, kpack);
     } else if (MatrixCoreVersion::RDNA_WMMA == matrixCoreVer &&
                enableWmmaTransform) {
       patterns.add<::BlockedToWMMA>(context);
@@ -395,10 +415,9 @@ class TritonAMDGPUAccelerateMatmulPass
   }
 };
 
-std::unique_ptr<Pass>
-mlir::createTritonAMDGPUAccelerateMatmulPass(std::string archGen,
-                                             int matrixInstructionSize,
-                                             bool enableWmmaTransform) {
+std::unique_ptr<Pass> mlir::createTritonAMDGPUAccelerateMatmulPass(
+    std::string archGen, int matrixInstructionSize, int kpack,
+    bool enableWmmaTransform) {
   return std::make_unique<TritonAMDGPUAccelerateMatmulPass>(
-      archGen, matrixInstructionSize, enableWmmaTransform);
+      archGen, matrixInstructionSize, kpack, enableWmmaTransform);
 }
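
Numerically, the kWidth selection in BlockedToMFMA above works out as in this sketch (grouping factors follow the comments in the hunk; values shown are for f16, where k_base = 4):

```python
# Sketch of the kWidth choice in BlockedToMFMA (f16 numbers). kWidth
# starts at k_base and is scaled by kpack unless this is the second dot
# of a flash-attention-style chain.
def choose_k_width(m_dim, n_dim, k_dim, kpack, is_second_dot):
    if m_dim == n_dim == 32:
        k_width = k_dim // 2  # operand elements grouped in 2 groups
    elif m_dim == n_dim == 16:
        k_width = k_dim // 4  # operand elements grouped in 4 groups
    else:  # mfma 4x64 / 64x4
        k_width = k_dim
    # The second dot keeps k_width = k_base: its input already carries the
    # first dot's mfma layout.
    if not is_second_dot:
        k_width *= kpack
    return k_width

print(choose_k_width(32, 32, 8, kpack=2, is_second_dot=False))  # 8
print(choose_k_width(32, 32, 8, kpack=2, is_second_dot=True))   # 4
```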
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -859,5 +859,6 @@ unsigned MfmaInsn::getKDim() { return attr.k; }
 unsigned MfmaInsn::getMDim() { return attr.m; }
 unsigned MfmaInsn::getNDim() { return attr.n; }
 StringRef MfmaInsn::getInsnName() { return attr.insn; }
+unsigned MfmaInsn::getKBase() { return attr.k_base; }
 
 } // namespace mlir
5 changes: 3 additions & 2 deletions python/src/triton.cc
@@ -1875,9 +1875,10 @@ void init_triton_ir(py::module &&m) {
             mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
       })
       .def("add_tritonamdgpu_accelerate_matmul_pass",
-           [](mlir::PassManager &self, const std::string archGenName, int instrSize) {
+           [](mlir::PassManager &self, const std::string archGenName,
+              int instrSize, int kpack) {
             self.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(
-                archGenName, instrSize));
+                archGenName, instrSize, kpack));
           })
       .def("add_tritongpu_optimize_dot_operands_pass",
           [](mlir::PassManager &self) {
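
Usage of the updated binding, as a sketch; pm is assumed to be the mlir PassManager that Triton's Python compile flow already constructs:

```python
# Sketch: driving the pass from Python via the binding above. Argument
# order is (archGenName, instrSize, kpack); pm is assumed to come from
# Triton's existing compile pipeline.
def add_amd_matmul_pass(pm, arch="gfx90a", instr_size=0, kpack=2):
    pm.add_tritonamdgpu_accelerate_matmul_pass(arch, instr_size, kpack)
```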