[mlir][AMDGPU] Add scaled floating point conversion ops #141554

Merged · 7 commits · Jun 13, 2025
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (62 additions, 0 deletions)

@@ -112,6 +112,38 @@ def AMDGPU_ExtPackedFp8Op :
}];
}

def AMDGPU_ScaledExtPackedOp
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
Arguments<(
ins AnyTypeOf<[VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[1, 2, 3, 4, 5, 6, 7, 8],
[F4E2M1FN]>]>:$source,
F32:$scale,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
Results<(
outs AnyTypeOf<[FixedVectorOfLengthAndType<[2], [F32]>,
FixedVectorOfLengthAndType<[2], [F16]>,
FixedVectorOfLengthAndType<[2], [BF16]>]>:$res)> {
let summary = "Extend a vector of packed floating point values";

let description = [{
Extend and scale two packed floats in `source[index]` to two floats and
return them.

Review comment on lines +130 to +131 (@umangyadav, Contributor, Jun 10, 2025):
I find this a bit confusing: `source[index]` would only point to one element. How do we get two floats in return?
Answer: it is selecting a byte of a 32-bit word. For F4, each byte holds two floats.

This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub-32-bit quantities, so the compiler intrinsics for
extending 8-bit floats (which are, currently, the only way to work with
this operation) take 32-bit words packing several such floats.

If the passed-in vector has fewer elements than the packed form, or the
input is scalar, the remaining values in the packed vector will be filled
with undefined values as needed.
}];
let assemblyFormat = [{
attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
}];
}
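
For concreteness, an editor's sketch of the concrete syntax implied by the assembly format above; the SSA value names, index values, and type choices are illustrative assumptions, not taken from the PR's tests:

// Extend the two packed fp8 values selected by the index, scaling by %scale.
%ext8 = amdgpu.scaled_ext_packed %src8[0], %scale : vector<4xf8E4M3FN> to vector<2xf32>
// For fp4 sources, the selected byte holds two 4-bit floats.
%ext4 = amdgpu.scaled_ext_packed %src4[1], %scale : vector<8xf4E2M1FN> to vector<2xf16>
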

def AMDGPU_PackedTrunc2xFp8Op :
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins F32:$sourceA,
@@ -139,6 +171,36 @@ def AMDGPU_PackedTrunc2xFp8Op :
let hasVerifier = 1;
}

def AMDGPU_PackedScaledTruncOp
: AMDGPU_Op<"packed_scaled_trunc", [Pure]>,

Review thread on the op name:
Contributor: nit: the ext op is named scaled_ext_packed, but this one is packed_scaled_trunc. I find it better to call it scaled_trunc_packed.
Contributor: That follows the existing convention in the dialect.
Contributor: We have ext_packed (the things being extended are packed) but packed_trunc (the result is packed).

Arguments<(ins VectorOfLengthAndType<[1, 2], [F32, F16, BF16]>:$source,
F32:$scale,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index,
Optional<AnyTypeOf<
[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>,
FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>>:$existing)>,
Results<(
outs AnyTypeOf<[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>,
FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>:$res)> {
let summary = "Round two floats into a packed vector of floats";
let description = [{
Scale and round the inputs `source` into the low or high word (bottom two
or top two) elements

Review thread on the index semantics:
Contributor: Is there a selector for choosing the low or high word? What is the role of the index attribute here? Can you explain?
Contributor: This is either the low/high half or byte 0-3 depending, probably.

of the returned vector, keeping the other two elements of `existing`
unchanged if present (or undefined if it was not passed in).

The reason for this odd signature is that AMD GPUs cannot easily work with
sub-registers, and so the conversion intrinsics take 32-bit wide
packed vectors of float values.
}];
let assemblyFormat = [{
attr-dict $source `into` ($existing^):(`undef`)? `[` $index `]`
`,` $scale
`:` type($source) `to` type($res) (`into` type($existing)^)?
}];
let hasVerifier = 1;
}
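
Similarly, an editor's sketch of the truncation op's syntax under the format above (value names and indices are hypothetical; the existing and result types must match per the verifier):

// Scale and truncate two f32 values into the slot selected by index 0 of a fresh packed fp8 vector.
%t0 = amdgpu.packed_scaled_trunc %src32 into undef[0], %scale : vector<2xf32> to vector<4xf8E4M3FN>
// Same, but writing into an existing packed vector and keeping its other lanes.
%t1 = amdgpu.packed_scaled_trunc %src16 into %old[1], %scale : vector<2xf16> to vector<4xf8E5M2> into vector<4xf8E5M2>
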

def AMDGPU_PackedStochRoundFp8Op :
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
Arguments<(ins F32:$source,
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (188 additions, 1 deletion)

@@ -24,6 +24,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
@@ -1174,6 +1175,32 @@ struct PackedStochRoundFp8OpLowering final
PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedOpLowering final
: public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
chipset(chipset) {}
Chipset chipset;

LogicalResult
matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct PackedScaledTruncOpLowering final
: public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
chipset(chipset) {}
Chipset chipset;

LogicalResult
matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
@@ -1230,6 +1257,165 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}

LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset != kGfx950)
return rewriter.notifyMatchFailure(
loc, "Scaled fp conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

Value source = adaptor.getSource();
Value scale = adaptor.getScale();

VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
Type sourceElemType = sourceVecType.getElementType();
VectorType destVecType = cast<VectorType>(op.getResult().getType());
Type destElemType = destVecType.getElementType();

VectorType packedVecType;
if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
} else if (isa<Float4E2M1FNType>(sourceElemType)) {
VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
} else {
llvm_unreachable("invalid element type for scaled ext");
}

// Extend to a packedVectorType
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);

Review comment (@umangyadav, Contributor, Jun 13, 2025): I noticed later that for the FP8/BF8 -> FP32/FP16 conversions we have non-packed instructions as well:
https://github.com/llvm/llvm-project/blob/4f8187c0dc6e7a818ebf3272a0c022203f901e96/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td#L976C7-L976C28
We don't need to go through the padding for those unnecessarily.
Reply (Member, Author): I chose to ignore those for now, especially since the FP16 variants introduce a second index attribute (one for the packed input and another for the output, an index into v2f16).

if (!sourceVecType) {
longVec = rewriter.create<LLVM::InsertElementOp>(
loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
longVec =
rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
}
}
source = longVec;
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);

if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else
return failure();

Review suggestion on lines +1307 to +1335 (Contributor): restructure the flat if/else-if chain above as nested ifs, grouped by source element type:
if (isa<Float8E5M2Type>(sourceElemType)) {
if (destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float8E4M3FNType>(sourceElemType)) {
if (destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float4E2M1FNType>(sourceElemType)) {
if (destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else {
return failure();
}
} else {
return failure();
}

Reply (Member, Author): I feel like this makes the code less readable.
Reply (Contributor): IMO, what I suggested feels better to read. Anyway, it's a personal opinion.
Reply (Contributor): We could have a static std::tuple<OperationName, StringAttr, IntegerType> getConversionOpNameAndSelectorArg(Type inType, Type outType); that would return the information needed to construct this operation generically. That'd save us all the duplicated op, destVecType, i32Source, scale, op.getIndex() sections. Basically, have we considered doing it the way the MFMA/WMMA code handles all this? (And, based on my experience with said code, the nested ifs Umang proposed end up a bit easier to work with... though, in retrospect, a lot of those could've been an llvm::TypeSwitch.)
Reply (Contributor): (Said note is non-blocking if you'd really rather not rewrite all this again, but I thought I'd raise it.)


return success();
}

LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset != kGfx950)
return rewriter.notifyMatchFailure(
loc, "Scaled fp conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type v2i16 = getTypeConverter()->convertType(
VectorType::get(2, rewriter.getI16Type()));
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

Type resultType = op.getResult().getType();
Type resultElemType = getElementTypeOrSelf(resultType);
VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
Type sourceElemType = sourceVecType.getElementType();

Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

Value source = adaptor.getSource();
Value scale = adaptor.getScale();
Value existing = adaptor.getExisting();
if (existing)
existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
else
existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);

if (sourceVecType.getNumElements() < 2) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
VectorType v2 = VectorType::get(2, sourceElemType);
source = rewriter.create<LLVM::ZeroOp>(loc, v2);
source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
}

Value sourceA, sourceB;
if (sourceElemType.isF32()) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value c1 = createI32Constant(rewriter, loc, 1);
sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
}

Value result;
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();

Review suggestion on lines +1384 to +1412 (Contributor): likewise, restructure the if/else-if chain above as nested ifs, grouped by result element type:
if (isa<Float8E5M2Type>(resultElemType)) {
if (sourceElemType.isF32()) {
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
} else if (sourceElemType.isF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else if (sourceElemType.isBF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float8E4M3FNType>(resultElemType)) {
if (sourceElemType.isF32()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
} else if (sourceElemType.isF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else if (sourceElemType.isBF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float4E2M1FNType>(resultElemType)) {
if (sourceElemType.isF32()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
} else if (sourceElemType.isF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else if (sourceElemType.isBF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else {
return failure();
}
} else {
return failure();
}


result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -1547,7 +1733,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (9 additions, 0 deletions)

@@ -60,6 +60,15 @@ LogicalResult PackedStochRoundFp8Op::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// mxfp float ops
//===----------------------------------------------------------------------===//
LogicalResult PackedScaledTruncOp::verify() {
if (getExisting() && getExisting().getType() != getResult().getType())
return emitOpError("existing values must have same type as result");
return success();
}

//===----------------------------------------------------------------------===//
// FatRawBufferCastOp
//===----------------------------------------------------------------------===//