[AMD][WMMA] Support dot3d (triton-lang#3674)
This PR enables support for the 3D dot operation on RDNA GPUs (see the usage sketch below).
binarman committed May 28, 2024
1 parent 706174d commit 100e2aa
Showing 9 changed files with 272 additions and 131 deletions.
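
For context, here is a minimal sketch of the kind of kernel this enables: tl.dot on rank-3 operands performs a batched matmul, which can now lower to WMMA instructions on RDNA hardware. The kernel, shapes, and single-program grid are illustrative assumptions, not code from this PR.

# Illustrative only: a batched (3D) dot in Triton. Assumes a build with this
# PR and an RDNA3 (gfx11) GPU; shapes meet the 16x16x16 WMMA minimum.
import torch
import triton
import triton.language as tl

@triton.jit
def batched_dot_kernel(a_ptr, b_ptr, c_ptr,
                       B: tl.constexpr, M: tl.constexpr,
                       N: tl.constexpr, K: tl.constexpr):
    # One program computes the whole (deliberately tiny) problem.
    offs_b = tl.arange(0, B)
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # A is (B, M, K) and B is (B, K, N), both contiguous.
    a = tl.load(a_ptr + offs_b[:, None, None] * M * K
                      + offs_m[None, :, None] * K
                      + offs_k[None, None, :])
    b = tl.load(b_ptr + offs_b[:, None, None] * K * N
                      + offs_k[None, :, None] * N
                      + offs_n[None, None, :])
    # Rank-3 dot: one batched matmul, lowered tile by tile per batch slice.
    c = tl.dot(a, b, input_precision="ieee")
    tl.store(c_ptr + offs_b[:, None, None] * M * N
                   + offs_m[None, :, None] * N
                   + offs_n[None, None, :], c)

a = torch.randn(2, 16, 32, dtype=torch.float16, device="cuda")
b = torch.randn(2, 32, 16, dtype=torch.float16, device="cuda")
c = torch.empty(2, 16, 16, dtype=torch.float32, device="cuda")
batched_dot_kernel[(1,)](a, b, c, B=2, M=16, N=16, K=32)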
85 changes: 61 additions & 24 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -908,14 +908,21 @@ emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout,

 inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout,
                                  SmallVector<SmallVector<unsigned>> &offsets,
-                                 unsigned ctaOffsetX, unsigned ctaOffsetY) {
+                                 unsigned ctaBatchOffset, unsigned ctaOffsetX,
+                                 unsigned ctaOffsetY) {
   const unsigned elemsPerThreadPerGroup = 8;
   auto warpSize = getWarpSize(wmmaLayout);
   assert(warpSize == 32);
   auto shapePerCta = getShapePerCTATile(wmmaLayout);
+  auto rank = shapePerCta.size();
+  assert(rank == 2 || rank == 3);
+  SmallVector<unsigned> elemOffset(rank, 0);
+  if (rank == 3)
+    elemOffset[0] = ctaBatchOffset;
   for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
-    offsets.push_back(
-        {ctaOffsetX * shapePerCta[0] + 2 * elem, ctaOffsetY * shapePerCta[1]});
+    elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem;
+    elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1];
+    offsets.push_back(elemOffset);
   }
 }

@@ -925,9 +932,11 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter,
                            RankedTensorType type) {
   auto shape = type.getShape();
   auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA();
-  assert(_warpsPerCTA.size() == 2);
-  SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
-                                    i32_val(_warpsPerCTA[1])};
+  auto rank = _warpsPerCTA.size();
+  assert(rank == 2 || rank == 3);
+  SmallVector<Value> warpsPerCTA;
+  for (unsigned i = 0; i < rank; ++i)
+    warpsPerCTA.push_back(i32_val(_warpsPerCTA[i]));
   auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();

   Value threadId = getThreadId(rewriter, loc);
@@ -940,20 +949,34 @@
   SmallVector<Value> multiDimWarpId =
       delinearize(rewriter, loc, warpId, _warpsPerCTA,
                   triton::gpu::getWarpOrder(wmmaLayout));
-  if (shape[0] >= mnkDim[0]) {
-    assert(shape[0] % mnkDim[0] == 0);
-    multiDimWarpId[0] =
-        urem(multiDimWarpId[0], i32_val(ceil<unsigned>(shape[0], mnkDim[0])));
+  if (shape[rank - 2] >= mnkDim[0]) {
+    assert(shape[rank - 2] % mnkDim[0] == 0);
+    multiDimWarpId[rank - 2] =
+        urem(multiDimWarpId[rank - 2],
+             i32_val(ceil<unsigned>(shape[rank - 2], mnkDim[0])));
   }
-  if (shape[1] >= mnkDim[1]) {
-    assert(shape[1] % mnkDim[1] == 0);
-    multiDimWarpId[1] =
-        urem(multiDimWarpId[1], i32_val(ceil<unsigned>(shape[1], mnkDim[1])));
+  if (shape[rank - 1] >= mnkDim[1]) {
+    assert(shape[rank - 1] % mnkDim[1] == 0);
+    multiDimWarpId[rank - 1] =
+        urem(multiDimWarpId[rank - 1],
+             i32_val(ceil<unsigned>(shape[rank - 1], mnkDim[1])));
   }
-  Value offWarp0 = mul(multiDimWarpId[0], i32_val(mnkDim[0]));
-  Value offWarp1 = mul(multiDimWarpId[1], i32_val(mnkDim[1]));
-  return {add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0),
-          add(laneId, offWarp1)};
+  Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0]));
+  Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(mnkDim[1]));
+
+  SmallVector<Value> multiDimBase(rank);
+
+  multiDimBase[rank - 2] =
+      add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0);
+  multiDimBase[rank - 1] = add(laneId, offWarp1);
+
+  // TODO: It is assumed when rank = 3, warpsPerCTA is set to
+  // {numWarps, 1, 1}. We need to generalize the offset computation.
+  if (rank == 3) {
+    assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1);
+    multiDimBase[0] = urem(warpId, i32_val(shape[0]));
+  }
+  return multiDimBase;
 }

 inline SmallVector<SmallVector<unsigned>>
@@ -964,17 +987,31 @@ emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout,
   auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape);
   auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();

-  SmallVector<unsigned> numWarpsPerDim(2);
+  auto rank = tensorShape.size();
+  assert(rank == 2 || rank == 3);
+
+  SmallVector<unsigned> numWarpsPerDim(rank, 1);
   auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();
-  for (unsigned d = 0; d < 2; ++d) {
+  SmallVector<unsigned> shapePerWarp(rank, 1);
+  shapePerWarp[rank - 2] = mnkDim[0];
+  shapePerWarp[rank - 1] = mnkDim[1];
+  for (unsigned d = 0; d < rank; ++d) {
     unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCTA[d]);
     unsigned inPerWarp = ceil<unsigned>(inPerCTA, warpsPerCTA[d]);
-    numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mnkDim[d]);
+    numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, shapePerWarp[d]);
   }

-  for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) {
-    for (unsigned j = 0; j < numWarpsPerDim[1]; ++j) {
-      emitWmmaOffsetForCTA(wmmaLayout, offsets, i, j);
+  unsigned repBatch = rank == 3 ? numWarpsPerDim[0] : 1;
+  unsigned repM = numWarpsPerDim[rank - 2];
+  unsigned repN = numWarpsPerDim[rank - 1];
+  auto warpsPerBatch =
+      rank == 3 ? std::min<unsigned>(tensorShape[0], warpsPerCTA[0]) : 1;
+
+  for (unsigned b = 0; b < repBatch; ++b) {
+    for (unsigned i = 0; i < repM; ++i) {
+      for (unsigned j = 0; j < repN; ++j) {
+        emitWmmaOffsetForCTA(wmmaLayout, offsets, b * warpsPerBatch, i, j);
+      }
     }
   }
   return offsets;
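
To make the loop structure in emitOffsetForWmmaLayout concrete, here is a small Python sketch of the rep computation (hand-rolled names, not Triton API; assumes a 16x16x16 WMMA tile and, for simplicity, shapePerCTA equal to the tensor shape):

# Hypothetical mirror of the numWarpsPerDim / rep computation above.
def ceildiv(a, b):
    return -(-a // b)

def wmma_cta_reps(tensor_shape, warps_per_cta, mnk=(16, 16, 16)):
    rank = len(tensor_shape)
    assert rank in (2, 3)
    shape_per_warp = [1] * rank
    shape_per_warp[rank - 2] = mnk[0]   # M tile
    shape_per_warp[rank - 1] = mnk[1]   # N tile
    num_warps_per_dim = []
    for d in range(rank):
        in_per_warp = ceildiv(tensor_shape[d], warps_per_cta[d])
        num_warps_per_dim.append(ceildiv(in_per_warp, shape_per_warp[d]))
    rep_batch = num_warps_per_dim[0] if rank == 3 else 1
    return rep_batch, num_warps_per_dim[rank - 2], num_warps_per_dim[rank - 1]

# 4x32x32 accumulator with warpsPerCTA = [2, 1, 4]:
# batch ceildiv(4,2)=2, M ceildiv(32,1*16)=2, N ceildiv(ceildiv(32,4),16)=1.
print(wmma_cta_reps((4, 32, 32), (2, 1, 4)))  # -> (2, 2, 1)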
8 changes: 6 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -469,6 +469,7 @@ bool supportMFMA(triton::DotOp op) {
   auto bShape = bTy.getShape();

   auto rank = aShape.size();
+  assert(bShape.size() == rank);
   auto M = aShape[rank - 2];
   auto N = bShape[rank - 1];
   auto K = aShape[rank - 1];
@@ -521,8 +522,11 @@ bool supportWMMA(triton::DotOp op) {
   auto aShape = aTy.getShape();
   auto bShape = bTy.getShape();

-  assert(aShape[1] == bShape[0]);
-  if (!supportWMMAGranularity(aShape[0], bShape[1], aShape[1]))
+  auto rank = aShape.size();
+  assert(bShape.size() == rank);
+  assert(aShape[rank - 1] == bShape[rank - 2]);
+  if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1],
+                              aShape[rank - 1]))
     return false;

   return true;
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -590,7 +590,7 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
     emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0],
                          multiDimCTAInRepId[1]);
   } else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout)) {
-    emitWmmaOffsetForCTA(wmmaLayout, offsets, multiDimCTAInRepId[0],
+    emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0],
                          multiDimCTAInRepId[1]);
   }
   multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
75 changes: 55 additions & 20 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -804,16 +804,22 @@ SmallVector<unsigned>
 AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                        Type eltTy) const {
   size_t rank = shape.size();
-  assert(rank == 2 && "Unexpected rank of wmma layout");
+  assert((rank == 2 || rank == 3) && "Unexpected rank of wmma layout");

+  SmallVector<unsigned> elemsPerThread(rank);
   auto mnkDim = getMNKDimPerWMMAInstr();
   auto elemsPerThreadPerTile = getSizePerThread();
   auto warpsPerCTA = getWarpsPerCTA();
-  return {ceil<unsigned>(shape[0], mnkDim[0] * warpsPerCTA[0]) *
-              elemsPerThreadPerTile[0],
-          ceil<unsigned>(shape[1], mnkDim[1] * warpsPerCTA[1]) *
-              elemsPerThreadPerTile[1]};
+
+  if (rank == 3)
+    elemsPerThread[0] = ceil<unsigned>(shape[0], getWarpsPerCTA()[0]);
+  elemsPerThread[rank - 2] =
+      ceil<unsigned>(shape[rank - 2], mnkDim[0] * warpsPerCTA[rank - 2]) *
+      elemsPerThreadPerTile[rank - 2];
+  elemsPerThread[rank - 1] =
+      ceil<unsigned>(shape[rank - 1], mnkDim[1] * warpsPerCTA[rank - 1]) *
+      elemsPerThreadPerTile[rank - 1];
+  return elemsPerThread;
 }

 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
@@ -1605,9 +1611,8 @@ AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef<int64_t> operandShape,

 unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
-  constexpr int warpSize = 64;
   auto rep = getMFMARepForOperands(shape, kWidth, opIdx);
-  return rep[0] * rep[1] * rep[2] * kWidth;
+  return product(rep) * kWidth;
 }

 SmallVector<unsigned>
@@ -1646,8 +1651,14 @@ AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,

 SmallVector<unsigned>
 AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = warpsPerCTA.size();
+  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
+
   auto mnkDim = getMNKDimPerWMMAInstr();
-  return {mnkDim[0] * getWarpsPerCTA()[0], mnkDim[1] * getWarpsPerCTA()[1]};
+  shapePerCTATile[rank - 2] *= mnkDim[0];
+  shapePerCTATile[rank - 1] *= mnkDim[1];
+  return shapePerCTATile;
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
@@ -1668,28 +1679,43 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
   return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadsPerWarp() const {
-  return {getMNKDimPerWMMAInstr()[0] / getSizePerThread()[0],
-          getMNKDimPerWMMAInstr()[1] / getSizePerThread()[1]};
+  auto rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> threads(rank, 1);
+  auto mnkInstr = getMNKDimPerWMMAInstr();
+  threads[rank - 2] = mnkInstr[0] / getSizePerThread()[rank - 2];
+  threads[rank - 1] = mnkInstr[1] / getSizePerThread()[rank - 1];
+  return threads;
 }

 SmallVector<unsigned> AMDWmmaEncodingAttr::getSizePerThread() const {
-  return {8, 1};
+  auto rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> sizePerThread(rank, 1);
+  sizePerThread[rank - 2] = 8;
+  sizePerThread[rank - 1] = 1;
+  return sizePerThread;
 }
 SmallVector<unsigned>
 AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> sizePerThread(rank, 1);
   if (opIdx == 0) {
-    return {1, 16};
+    sizePerThread[rank - 2] = 1;
+    sizePerThread[rank - 1] = 16;
   } else if (opIdx == 1) {
-    return {16, 1};
+    sizePerThread[rank - 2] = 16;
+    sizePerThread[rank - 1] = 1;
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
+  return sizePerThread;
 }

 SmallVector<unsigned>
 AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
                                                       int opIdx) const {
   auto parentShapePerCTA = getShapePerCTATile(shape);
+  auto rank = shape.size();
+  assert(rank == 2);
   if (opIdx == 0) {
     return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
   } else if (opIdx == 1) {
@@ -1702,7 +1728,7 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx);
-  return rep[0] * rep[1] * kWidth;
+  return product(rep) * kWidth;
 }

 SmallVector<int64_t>
@@ -1715,16 +1741,25 @@ AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef<int64_t> operandShape,
                                            Type elemType, int kWidth,
                                            int opIdx) const {
   auto operandTileShape = getWMMAElemsPerInstrForOperands();
+  assert(operandTileShape.size() == 2);
   auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = operandShape.size();
+  assert(rank == 2 || rank == 3);
+  int numRepBatch =
+      rank == 3 ? std::max<int64_t>(1, operandShape[0] / warpsPerCTA[0]) : 1;
   if (opIdx == 0)
-    return {std::max<int64_t>(1, operandShape[0] /
-                                     (operandTileShape[0] * warpsPerCTA[0])),
-            std::max<int64_t>(1, operandShape[1] / operandTileShape[1])};
+    return {
+        numRepBatch,
+        std::max<int64_t>(1, operandShape[rank - 2] /
+                                 (operandTileShape[0] * warpsPerCTA[rank - 2])),
+        std::max<int64_t>(1, operandShape[rank - 1] / operandTileShape[1])};
   else {
     assert(opIdx == 1);
-    return {std::max<int64_t>(1, operandShape[0] / operandTileShape[0]),
-            std::max<int64_t>(1, operandShape[1] /
-                                     (operandTileShape[1] * warpsPerCTA[1]))};
+    return {
+        numRepBatch,
+        std::max<int64_t>(1, operandShape[rank - 2] / operandTileShape[0]),
+        std::max<int64_t>(1, operandShape[rank - 1] / (operandTileShape[1] *
+                                                       warpsPerCTA[rank - 1]))};
   }
 }

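
Likewise, a quick numeric check on getWMMARepForOperands for a rank-3 operand (again a hand-rolled Python sketch; the 16x16 operand tile and warpsPerCTA = [2, 1, 4] mirror the lit test added below):

# Hand-rolled mirror of getWMMARepForOperands for rank-3 operands.
def wmma_operand_reps(shape, warps_per_cta, op_idx, tile=(16, 16)):
    rank = len(shape)
    rep_batch = max(1, shape[0] // warps_per_cta[0]) if rank == 3 else 1
    if op_idx == 0:   # A operand, shape (B, M, K)
        return [rep_batch,
                max(1, shape[rank - 2] // (tile[0] * warps_per_cta[rank - 2])),
                max(1, shape[rank - 1] // tile[1])]
    assert op_idx == 1  # B operand, shape (B, K, N)
    return [rep_batch,
            max(1, shape[rank - 2] // tile[0]),
            max(1, shape[rank - 1] // (tile[1] * warps_per_cta[rank - 1]))]

# A operand 4x16x32 with warpsPerCTA = [2, 1, 4]: 2 batch reps, 1 M rep, 2 K reps.
print(wmma_operand_reps((4, 16, 32), (2, 1, 4), 0))  # -> [2, 1, 2]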
5 changes: 5 additions & 0 deletions python/test/unit/language/test_core.py
@@ -3262,6 +3262,11 @@ def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device):
     if is_hip():
         # hip does not support tf32 precision, so use ieee for all tests
         input_precision = "ieee"
+        if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
+            if in_dtype_str == "float32":
+                pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
+            if out_dtype_str == "float16":
+                pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
     else:
         input_precision = "tf32" if in_dtype_str == 'float32' else "ieee"
30 changes: 30 additions & 0 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
@@ -57,3 +57,33 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}>
+#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 1, 4]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: wmma_dot_operand3d
+  tt.func @wmma_dot_operand3d(%arg0: !tt.memdesc<4x16x32xf16, #shared>) {
+    // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16>
+    %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16>
+    %1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma_dot3d
+  tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma>) {
+    // CHECK-COUNT-32: llvm.extractvalue %arg0
+    // CHECK-COUNT-32: llvm.insertelement
+    // CHECK-COUNT-32: llvm.extractvalue %arg1
+    // CHECK-COUNT-32: llvm.insertelement
+    // CHECK-COUNT-8: llvm.extractvalue %arg2
+    // CHECK-COUNT-8: llvm.insertelement
+    // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+    %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<2x16x16xf16, #mma>
+    // CHECK-COUNT-8: llvm.extractelement
+    // CHECK-COUNT-8: llvm.insertvalue
+    tt.return
+  }
+}
@@ -212,6 +212,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,

   auto sharedLayout = cast<SharedEncodingAttr>(aTensorTy.getEncoding());
   auto order = sharedLayout.getOrder();
+  assert((rank == 2 || order[2] == 0) &&
+         "expect batch to be the slowest dimension");

   auto elemTy = aTensorTy.getElementType();
   auto kWidth = encoding.getKWidth();