|
24 | 24 | #include "llvm/ADT/STLExtras.h"
|
25 | 25 | #include "llvm/ADT/TypeSwitch.h"
|
26 | 26 | #include "llvm/Support/Casting.h"
|
| 27 | +#include "llvm/Support/ErrorHandling.h" |
27 | 28 | #include <optional>
|
28 | 29 |
|
29 | 30 | namespace mlir {
|
@@ -1174,6 +1175,32 @@ struct PackedStochRoundFp8OpLowering final
|
1174 | 1175 | PackedStochRoundFp8OpAdaptor adaptor,
|
1175 | 1176 | ConversionPatternRewriter &rewriter) const override;
|
1176 | 1177 | };
|
| 1178 | + |
| 1179 | +struct ScaledExtPackedOpLowering final |
| 1180 | + : public ConvertOpToLLVMPattern<ScaledExtPackedOp> { |
| 1181 | + ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
| 1182 | + : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter), |
| 1183 | + chipset(chipset) {} |
| 1184 | + Chipset chipset; |
| 1185 | + |
| 1186 | + LogicalResult |
| 1187 | + matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, |
| 1188 | + ConversionPatternRewriter &rewriter) const override; |
| 1189 | +}; |
| 1190 | + |
| 1191 | +struct PackedScaledTruncOpLowering final |
| 1192 | + : public ConvertOpToLLVMPattern<PackedScaledTruncOp> { |
| 1193 | + PackedScaledTruncOpLowering(const LLVMTypeConverter &converter, |
| 1194 | + Chipset chipset) |
| 1195 | + : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter), |
| 1196 | + chipset(chipset) {} |
| 1197 | + Chipset chipset; |
| 1198 | + |
| 1199 | + LogicalResult |
| 1200 | + matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor, |
| 1201 | + ConversionPatternRewriter &rewriter) const override; |
| 1202 | +}; |
| 1203 | + |
1177 | 1204 | } // end namespace
|
1178 | 1205 |
|
1179 | 1206 | LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
|
@@ -1230,6 +1257,165 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
|
1230 | 1257 | return success();
|
1231 | 1258 | }
|
1232 | 1259 |
|
| 1260 | +LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( |
| 1261 | + ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, |
| 1262 | + ConversionPatternRewriter &rewriter) const { |
| 1263 | + Location loc = op.getLoc(); |
| 1264 | + if (chipset != kGfx950) |
| 1265 | + return rewriter.notifyMatchFailure( |
| 1266 | + loc, "Scaled fp conversion instructions are not available on target " |
| 1267 | + "architecture and their emulation is not implemented"); |
| 1268 | + Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| 1269 | + |
| 1270 | + Value source = adaptor.getSource(); |
| 1271 | + Value scale = adaptor.getScale(); |
| 1272 | + |
| 1273 | + VectorType sourceVecType = cast<VectorType>(op.getSource().getType()); |
| 1274 | + Type sourceElemType = sourceVecType.getElementType(); |
| 1275 | + VectorType destVecType = cast<VectorType>(op.getResult().getType()); |
| 1276 | + Type destElemType = destVecType.getElementType(); |
| 1277 | + |
| 1278 | + VectorType packedVecType; |
| 1279 | + if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) { |
| 1280 | + VectorType v4i8 = VectorType::get(4, rewriter.getI8Type()); |
| 1281 | + packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8)); |
| 1282 | + } else if (isa<Float4E2M1FNType>(sourceElemType)) { |
| 1283 | + VectorType v8i4 = VectorType::get(8, rewriter.getI4Type()); |
| 1284 | + packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4)); |
| 1285 | + } else { |
| 1286 | + llvm_unreachable("invalid element type for scaled ext"); |
| 1287 | + } |
| 1288 | + |
| 1289 | + // Extend to a packedVectorType |
| 1290 | + if (sourceVecType.getNumElements() < packedVecType.getNumElements()) { |
| 1291 | + Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType); |
| 1292 | + if (!sourceVecType) { |
| 1293 | + longVec = rewriter.create<LLVM::InsertElementOp>( |
| 1294 | + loc, longVec, source, createI32Constant(rewriter, loc, 0)); |
| 1295 | + } else { |
| 1296 | + for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { |
| 1297 | + Value idx = createI32Constant(rewriter, loc, i); |
| 1298 | + Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); |
| 1299 | + longVec = |
| 1300 | + rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); |
| 1301 | + } |
| 1302 | + } |
| 1303 | + source = longVec; |
| 1304 | + } |
| 1305 | + Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); |
| 1306 | + |
| 1307 | + if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32()) |
| 1308 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>( |
| 1309 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1310 | + else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16()) |
| 1311 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>( |
| 1312 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1313 | + else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16()) |
| 1314 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>( |
| 1315 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1316 | + else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32()) |
| 1317 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>( |
| 1318 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1319 | + else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16()) |
| 1320 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>( |
| 1321 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1322 | + else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16()) |
| 1323 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>( |
| 1324 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1325 | + else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32()) |
| 1326 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>( |
| 1327 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1328 | + else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16()) |
| 1329 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>( |
| 1330 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1331 | + else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16()) |
| 1332 | + rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>( |
| 1333 | + op, destVecType, i32Source, scale, op.getIndex()); |
| 1334 | + else |
| 1335 | + return failure(); |
| 1336 | + |
| 1337 | + return success(); |
| 1338 | +} |
| 1339 | + |
| 1340 | +LogicalResult PackedScaledTruncOpLowering::matchAndRewrite( |
| 1341 | + PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor, |
| 1342 | + ConversionPatternRewriter &rewriter) const { |
| 1343 | + Location loc = op.getLoc(); |
| 1344 | + if (chipset != kGfx950) |
| 1345 | + return rewriter.notifyMatchFailure( |
| 1346 | + loc, "Scaled fp conversion instructions are not available on target " |
| 1347 | + "architecture and their emulation is not implemented"); |
| 1348 | + Type v2i16 = getTypeConverter()->convertType( |
| 1349 | + VectorType::get(2, rewriter.getI16Type())); |
| 1350 | + Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
| 1351 | + |
| 1352 | + Type resultType = op.getResult().getType(); |
| 1353 | + Type resultElemType = getElementTypeOrSelf(resultType); |
| 1354 | + VectorType sourceVecType = cast<VectorType>(op.getSource().getType()); |
| 1355 | + Type sourceElemType = sourceVecType.getElementType(); |
| 1356 | + |
| 1357 | + Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16; |
| 1358 | + |
| 1359 | + Value source = adaptor.getSource(); |
| 1360 | + Value scale = adaptor.getScale(); |
| 1361 | + Value existing = adaptor.getExisting(); |
| 1362 | + if (existing) |
| 1363 | + existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing); |
| 1364 | + else |
| 1365 | + existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType); |
| 1366 | + |
| 1367 | + if (sourceVecType.getNumElements() < 2) { |
| 1368 | + Value c0 = createI32Constant(rewriter, loc, 0); |
| 1369 | + Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0); |
| 1370 | + VectorType v2 = VectorType::get(2, sourceElemType); |
| 1371 | + source = rewriter.create<LLVM::ZeroOp>(loc, v2); |
| 1372 | + source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0); |
| 1373 | + } |
| 1374 | + |
| 1375 | + Value sourceA, sourceB; |
| 1376 | + if (sourceElemType.isF32()) { |
| 1377 | + Value c0 = createI32Constant(rewriter, loc, 0); |
| 1378 | + Value c1 = createI32Constant(rewriter, loc, 1); |
| 1379 | + sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0); |
| 1380 | + sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1); |
| 1381 | + } |
| 1382 | + |
| 1383 | + Value result; |
| 1384 | + if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType)) |
| 1385 | + result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>( |
| 1386 | + loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); |
| 1387 | + else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType)) |
| 1388 | + result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>( |
| 1389 | + loc, intResultType, existing, source, scale, op.getIndex()); |
| 1390 | + else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType)) |
| 1391 | + result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>( |
| 1392 | + loc, intResultType, existing, source, scale, op.getIndex()); |
| 1393 | + else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType)) |
| 1394 | + result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>( |
| 1395 | + loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); |
| 1396 | + else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType)) |
| 1397 | + result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>( |
| 1398 | + loc, intResultType, existing, source, scale, op.getIndex()); |
| 1399 | + else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType)) |
| 1400 | + result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>( |
| 1401 | + loc, intResultType, existing, source, scale, op.getIndex()); |
| 1402 | + else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType)) |
| 1403 | + result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>( |
| 1404 | + loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); |
| 1405 | + else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType)) |
| 1406 | + result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>( |
| 1407 | + loc, intResultType, existing, source, scale, op.getIndex()); |
| 1408 | + else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType)) |
| 1409 | + result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>( |
| 1410 | + loc, intResultType, existing, source, scale, op.getIndex()); |
| 1411 | + else |
| 1412 | + return failure(); |
| 1413 | + |
| 1414 | + result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
| 1415 | + op, getTypeConverter()->convertType(resultType), result); |
| 1416 | + return success(); |
| 1417 | +} |
| 1418 | + |
1233 | 1419 | LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
|
1234 | 1420 | PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
|
1235 | 1421 | ConversionPatternRewriter &rewriter) const {
|
@@ -1547,7 +1733,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
1547 | 1733 | ROCDL::RawPtrBufferAtomicCmpSwap>,
|
1548 | 1734 | AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
|
1549 | 1735 | MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
|
1550 |
| - ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering, |
| 1736 | + ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, |
| 1737 | + PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, |
1551 | 1738 | PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
|
1552 | 1739 | chipset);
|
1553 | 1740 | patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
|
|
0 commit comments