Skip to content

Commit f0d8ce5

Browse files
tgymnichtomtor
authored andcommitted
[mlir][AMDGPU] Add scaled floating point conversion ops (llvm#141554)
implement `ScaledExtPackedOp` and `PackedScaledTruncOp`
1 parent 893d7cd commit f0d8ce5

File tree

6 files changed

+1601
-1
lines changed

6 files changed

+1601
-1
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,38 @@ def AMDGPU_ExtPackedFp8Op :
112112
}];
113113
}
114114

115+
def AMDGPU_ScaledExtPackedOp
116+
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
117+
Arguments<(
118+
ins AnyTypeOf<[VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>,
119+
VectorOfLengthAndType<[1, 2, 3, 4, 5, 6, 7, 8],
120+
[F4E2M1FN]>]>:$source,
121+
F32:$scale,
122+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
123+
Results<(
124+
outs AnyTypeOf<[FixedVectorOfLengthAndType<[2], [F32]>,
125+
FixedVectorOfLengthAndType<[2], [F16]>,
126+
FixedVectorOfLengthAndType<[2], [BF16]>]>:$res)> {
127+
let summary = "Extend a vector of packed floating point values";
128+
129+
let description = [{
130+
Extend and scale two packed floats in `source[index]` to two floats and
131+
return them.
132+
133+
This rather unusual signature arises from the fact that AMD GPUs cannot
134+
easily work with sub 32-bit quantities, so the compiler intrinsics for
135+
extending 8-bit floats (which are, currently, the only way to work with
136+
this operation) take packed vectors of 2 such floats.
137+
138+
If the passed-in vector has fewer than two elements, or the input is scalar,
139+
the remaining values in the <2 x i8> will be filled with
140+
undefined values as needed.
141+
}];
142+
let assemblyFormat = [{
143+
attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
144+
}];
145+
}
146+
115147
def AMDGPU_PackedTrunc2xFp8Op :
116148
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
117149
Arguments<(ins F32:$sourceA,
@@ -139,6 +171,36 @@ def AMDGPU_PackedTrunc2xFp8Op :
139171
let hasVerifier = 1;
140172
}
141173

174+
def AMDGPU_PackedScaledTruncOp
175+
: AMDGPU_Op<"packed_scaled_trunc", [Pure]>,
176+
Arguments<(ins VectorOfLengthAndType<[1, 2], [F32, F16, BF16]>:$source,
177+
F32:$scale,
178+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index,
179+
Optional<AnyTypeOf<
180+
[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>,
181+
FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>>:$existing)>,
182+
Results<(
183+
outs AnyTypeOf<[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>,
184+
FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>:$res)> {
185+
let summary = "Round two floats into a packed vector of floats";
186+
let description = [{
187+
Scale and round the inputs `source` (which is undefined if not
188+
specified) into the low or high word (bottom two or top two) elements
189+
of the returned vector, keeping the other two elements of `existing`
190+
unchanged if present (or undefined if it was not passed in).
191+
192+
The reason for this odd signature is that AMD GPUs cannot easily work with
193+
sub-registers, and so the conversion intrinsics take 32-bit wide
194+
packed vectors of float values.
195+
}];
196+
let assemblyFormat = [{
197+
attr-dict $source `into` ($existing^):(`undef`)? `[` $index `]`
198+
`,` $scale
199+
`:` type($source) `to` type($res) (`into` type($existing)^)?
200+
}];
201+
let hasVerifier = 1;
202+
}
203+
142204
def AMDGPU_PackedStochRoundFp8Op :
143205
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
144206
Arguments<(ins F32:$source,

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 188 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/ADT/STLExtras.h"
2525
#include "llvm/ADT/TypeSwitch.h"
2626
#include "llvm/Support/Casting.h"
27+
#include "llvm/Support/ErrorHandling.h"
2728
#include <optional>
2829

2930
namespace mlir {
@@ -1174,6 +1175,32 @@ struct PackedStochRoundFp8OpLowering final
11741175
PackedStochRoundFp8OpAdaptor adaptor,
11751176
ConversionPatternRewriter &rewriter) const override;
11761177
};
1178+
1179+
struct ScaledExtPackedOpLowering final
1180+
: public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1181+
ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1182+
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1183+
chipset(chipset) {}
1184+
Chipset chipset;
1185+
1186+
LogicalResult
1187+
matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1188+
ConversionPatternRewriter &rewriter) const override;
1189+
};
1190+
1191+
struct PackedScaledTruncOpLowering final
1192+
: public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1193+
PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1194+
Chipset chipset)
1195+
: ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1196+
chipset(chipset) {}
1197+
Chipset chipset;
1198+
1199+
LogicalResult
1200+
matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1201+
ConversionPatternRewriter &rewriter) const override;
1202+
};
1203+
11771204
} // end namespace
11781205

11791206
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
@@ -1230,6 +1257,165 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
12301257
return success();
12311258
}
12321259

1260+
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1261+
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1262+
ConversionPatternRewriter &rewriter) const {
1263+
Location loc = op.getLoc();
1264+
if (chipset != kGfx950)
1265+
return rewriter.notifyMatchFailure(
1266+
loc, "Scaled fp conversion instructions are not available on target "
1267+
"architecture and their emulation is not implemented");
1268+
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1269+
1270+
Value source = adaptor.getSource();
1271+
Value scale = adaptor.getScale();
1272+
1273+
VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1274+
Type sourceElemType = sourceVecType.getElementType();
1275+
VectorType destVecType = cast<VectorType>(op.getResult().getType());
1276+
Type destElemType = destVecType.getElementType();
1277+
1278+
VectorType packedVecType;
1279+
if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1280+
VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1281+
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1282+
} else if (isa<Float4E2M1FNType>(sourceElemType)) {
1283+
VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1284+
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1285+
} else {
1286+
llvm_unreachable("invalid element type for scaled ext");
1287+
}
1288+
1289+
// Extend to a packedVectorType
1290+
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1291+
Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
1292+
if (!sourceVecType) {
1293+
longVec = rewriter.create<LLVM::InsertElementOp>(
1294+
loc, longVec, source, createI32Constant(rewriter, loc, 0));
1295+
} else {
1296+
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1297+
Value idx = createI32Constant(rewriter, loc, i);
1298+
Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
1299+
longVec =
1300+
rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1301+
}
1302+
}
1303+
source = longVec;
1304+
}
1305+
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
1306+
1307+
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1308+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1309+
op, destVecType, i32Source, scale, op.getIndex());
1310+
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1311+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1312+
op, destVecType, i32Source, scale, op.getIndex());
1313+
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1314+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1315+
op, destVecType, i32Source, scale, op.getIndex());
1316+
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1317+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1318+
op, destVecType, i32Source, scale, op.getIndex());
1319+
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1320+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1321+
op, destVecType, i32Source, scale, op.getIndex());
1322+
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1323+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1324+
op, destVecType, i32Source, scale, op.getIndex());
1325+
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1326+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1327+
op, destVecType, i32Source, scale, op.getIndex());
1328+
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1329+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1330+
op, destVecType, i32Source, scale, op.getIndex());
1331+
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1332+
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1333+
op, destVecType, i32Source, scale, op.getIndex());
1334+
else
1335+
return failure();
1336+
1337+
return success();
1338+
}
1339+
1340+
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1341+
PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1342+
ConversionPatternRewriter &rewriter) const {
1343+
Location loc = op.getLoc();
1344+
if (chipset != kGfx950)
1345+
return rewriter.notifyMatchFailure(
1346+
loc, "Scaled fp conversion instructions are not available on target "
1347+
"architecture and their emulation is not implemented");
1348+
Type v2i16 = getTypeConverter()->convertType(
1349+
VectorType::get(2, rewriter.getI16Type()));
1350+
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1351+
1352+
Type resultType = op.getResult().getType();
1353+
Type resultElemType = getElementTypeOrSelf(resultType);
1354+
VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1355+
Type sourceElemType = sourceVecType.getElementType();
1356+
1357+
Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1358+
1359+
Value source = adaptor.getSource();
1360+
Value scale = adaptor.getScale();
1361+
Value existing = adaptor.getExisting();
1362+
if (existing)
1363+
existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
1364+
else
1365+
existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
1366+
1367+
if (sourceVecType.getNumElements() < 2) {
1368+
Value c0 = createI32Constant(rewriter, loc, 0);
1369+
Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
1370+
VectorType v2 = VectorType::get(2, sourceElemType);
1371+
source = rewriter.create<LLVM::ZeroOp>(loc, v2);
1372+
source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
1373+
}
1374+
1375+
Value sourceA, sourceB;
1376+
if (sourceElemType.isF32()) {
1377+
Value c0 = createI32Constant(rewriter, loc, 0);
1378+
Value c1 = createI32Constant(rewriter, loc, 1);
1379+
sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
1380+
sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
1381+
}
1382+
1383+
Value result;
1384+
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1385+
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
1386+
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1387+
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1388+
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
1389+
loc, intResultType, existing, source, scale, op.getIndex());
1390+
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1391+
result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
1392+
loc, intResultType, existing, source, scale, op.getIndex());
1393+
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1394+
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
1395+
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1396+
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1397+
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
1398+
loc, intResultType, existing, source, scale, op.getIndex());
1399+
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1400+
result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
1401+
loc, intResultType, existing, source, scale, op.getIndex());
1402+
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1403+
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
1404+
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1405+
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1406+
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
1407+
loc, intResultType, existing, source, scale, op.getIndex());
1408+
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1409+
result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
1410+
loc, intResultType, existing, source, scale, op.getIndex());
1411+
else
1412+
return failure();
1413+
1414+
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1415+
op, getTypeConverter()->convertType(resultType), result);
1416+
return success();
1417+
}
1418+
12331419
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
12341420
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
12351421
ConversionPatternRewriter &rewriter) const {
@@ -1547,7 +1733,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
15471733
ROCDL::RawPtrBufferAtomicCmpSwap>,
15481734
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
15491735
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1550-
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1736+
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1737+
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
15511738
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
15521739
chipset);
15531740
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ LogicalResult PackedStochRoundFp8Op::verify() {
6060
return success();
6161
}
6262

63+
//===----------------------------------------------------------------------===//
64+
// mxfp float ops
65+
//===----------------------------------------------------------------------===//
66+
LogicalResult PackedScaledTruncOp::verify() {
67+
if (getExisting() && getExisting().getType() != getResult().getType())
68+
return emitOpError("existing values must have same type as result");
69+
return success();
70+
}
71+
6372
//===----------------------------------------------------------------------===//
6473
// FatRawBufferCastOp
6574
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)