Skip to content

[MLIR][NVVM] Rename cvt Ops to convert #140868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
let summary = "Convert the given float input to TF32";
let description = [{
This Op converts the given f32 input to tf32.
Expand Down Expand Up @@ -1062,24 +1062,24 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
}];

string llvmBuilder = [{
auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
auto intId = NVVM::ConvertFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
$res = createIntrinsicCall(builder, intId, {$src});
}];
}

// FP6 destination formats selectable on the f32x2 -> f6x2 conversion op:
// E2M3 (value 0, keyword `e2m3`) and E3M2 (value 1, keyword `e3m2`).
// NOTE(review): this is a rendered diff hunk, so each definition appears
// twice: the pre-rename `CVT*` line(s) followed by the post-rename
// `Convert*` line(s). Only the names differ; values and keywords are unchanged.
def CVTFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
def CVTFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;

// 32-bit integer enum grouping the two FP6 cases; the C++ enum name changes
// from CVTFP6Type to ConvertFP6Type.
def CVTFP6Type : I32EnumAttr<"CVTFP6Type", "NVVM CVTFP6Type kind",
[CVTFP6E2M3, CVTFP6E3M2]> {
def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind",
[ConvertFP6E2M3, ConvertFP6E3M2]> {
// Presumably disabled because the EnumAttr definition below supplies the
// attribute for this enum -- TODO confirm against the ODS documentation.
let genSpecializedAttr = 0;
// The generated enum is placed in the ::mlir::NVVM namespace.
let cppNamespace = "::mlir::NVVM";
}
// Dialect attribute wrapping the enum. Note that the rename also changes the
// attribute's name string from "cvt_fp6_type" to "convert_fp6_type".
def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
def ConvertFP6TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP6Type, "convert_fp6_type"> {
// Written in assembly as the enum keyword between angle brackets.
let assemblyFormat = "`<` $value `>`";
}

def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> {
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
This Op converts each of the given float inputs to the specified fp6 type.
Expand All @@ -1099,19 +1099,19 @@ def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> {

let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
CVTFP6TypeAttr:$type,
ConvertFP6TypeAttr:$type,
F32:$a,
F32:$b,
DefaultValuedAttr<BoolAttr, "false">:$relu);
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP6Type,
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type,
bool hasRelu);
}];

string llvmBuilder = [{
auto intId = NVVM::CvtF32x2ToF6x2Op::getIntrinsicID($type, $relu);
auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand All @@ -1121,20 +1121,20 @@ def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> {
}];
}

// FP8 destination formats used as the `$type` argument of the f32x2, f16x2
// and bf16x2 -> f8x2 conversion ops: E4M3 (value 0, keyword `e4m3`),
// E5M2 (value 1, keyword `e5m2`) and UE8M0 (value 2, keyword `ue8m0`).
// NOTE(review): this is a rendered diff hunk, so each definition appears
// twice: the pre-rename `CVT*` line(s) followed by the post-rename
// `Convert*` line(s). Only the names differ; values and keywords are unchanged.
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;

// 32-bit integer enum grouping the three FP8 cases; the C++ enum name changes
// from CVTFP8Type to ConvertFP8Type.
def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
[CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind",
[ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> {
// Presumably disabled because the EnumAttr definition below supplies the
// attribute for this enum -- TODO confirm against the ODS documentation.
let genSpecializedAttr = 0;
// The generated enum is placed in the ::mlir::NVVM namespace.
let cppNamespace = "::mlir::NVVM";
}
// Dialect attribute wrapping the enum. Note that the rename also changes the
// attribute's name string from "cvt_fp8_type" to "convert_fp8_type".
def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
def ConvertFP8TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP8Type, "convert_fp8_type"> {
// Written in assembly as the enum keyword between angle brackets.
let assemblyFormat = "`<` $value `>`";
}

def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
This Op converts each of the given float inputs to the specified fp8 type.
Expand All @@ -1155,7 +1155,7 @@ def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
CVTFP8TypeAttr:$type,
ConvertFP8TypeAttr:$type,
F32:$a,
F32:$b,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
Expand All @@ -1164,14 +1164,14 @@ def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu);
}];

string llvmBuilder = [{
auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand All @@ -1181,7 +1181,7 @@ def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
}];
}

def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> {
let summary = "Convert an f16x2 input to f8x2";
let description = [{
This Op converts the given f16 inputs in an f16x2 vector to the specified
Expand All @@ -1203,18 +1203,18 @@ def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
CVTFP8TypeAttr:$type,
ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [F16]>:$a,
DefaultValuedAttr<BoolAttr, "false">:$relu);
let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
bool hasRelu);
}];

string llvmBuilder = [{
auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand All @@ -1224,7 +1224,7 @@ def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
}];
}

def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let summary = "Convert a pair of bf16 inputs to f8x2";
let description = [{
This Op converts the given bf16 inputs in a bf16x2 vector to the specified
Expand All @@ -1246,7 +1246,7 @@ def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
CVTFP8TypeAttr:$type,
ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
Expand All @@ -1258,7 +1258,7 @@ def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
}];

string llvmBuilder = [{
auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand Down
63 changes: 32 additions & 31 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
getLoc());
}

LogicalResult CvtFloatToTF32Op::verify() {
LogicalResult ConvertFloatToTF32Op::verify() {
using RndMode = NVVM::FPRoundingMode;
switch (getRnd()) {
case RndMode::RNA:
Expand All @@ -129,12 +129,12 @@ LogicalResult CvtFloatToTF32Op::verify() {
break;
default:
return emitError(
"Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
"Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
}
return success();
}

LogicalResult CvtF32x2ToF8x2Op::verify() {
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;

Expand All @@ -146,16 +146,16 @@ LogicalResult CvtF32x2ToF8x2Op::verify() {
bool hasRelu = getRelu();

switch (getType()) {
case CVTFP8Type::E4M3:
case CVTFP8Type::E5M2:
case ConvertFP8Type::E4M3:
case ConvertFP8Type::E5M2:
if (!isRoundingModeRN)
return emitOpError("Only RN rounding mode is supported for conversions "
"from f32x2 to .e4m3x2 or .e5m2x2 types");
if (!isSatFinite)
return emitOpError("Only SATFINITE saturation mode is supported for "
"conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
break;
case CVTFP8Type::UE8M0:
case ConvertFP8Type::UE8M0:
if (!(isRoundingModeRZ || isRoundingModeRP))
return emitOpError("Only RZ or RP rounding modes are supported for "
"conversions from f32x2 to .ue8m0x2 type");
Expand All @@ -166,18 +166,18 @@ LogicalResult CvtF32x2ToF8x2Op::verify() {
return success();
}

// Verifier for the f16x2 -> f8x2 conversion op: emits an op error when the
// requested FP8 type is UE8M0 and reports success for every other type.
// NOTE(review): this is a rendered diff hunk, so the pre-rename
// (CvtF16x2ToF8x2Op / CVTFP8Type) lines are immediately followed by their
// post-rename (ConvertF16x2ToF8x2Op / ConvertFP8Type) replacements.
LogicalResult CvtF16x2ToF8x2Op::verify() {
if (getType() == CVTFP8Type::UE8M0)
LogicalResult ConvertF16x2ToF8x2Op::verify() {
if (getType() == ConvertFP8Type::UE8M0)
// UE8M0 is the only enum case rejected here; the diagnostic names the two
// accepted formats.
return emitOpError("Only .e4m3 or .e5m2 types are supported for "
"conversions from f16x2 to f8x2.");

return success();
}

LogicalResult CvtBF16x2ToF8x2Op::verify() {
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;

if (getType() != CVTFP8Type::UE8M0)
if (getType() != ConvertFP8Type::UE8M0)
return emitOpError(
"Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

Expand Down Expand Up @@ -1336,9 +1336,9 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
: CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu) {
llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
using RndMode = NVVM::FPRoundingMode;
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
switch (rnd) {
Expand All @@ -1357,14 +1357,15 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

// Picks the LLVM intrinsic ID for the f32x2 -> f6x2 conversion from the FP6
// destination type (E2M3 or E3M2) and the relu flag.
// NOTE(review): this is a rendered diff hunk, so the pre-rename
// (CvtF32x2ToF6x2Op / CVTFP6Type) lines are immediately followed by their
// post-rename (ConvertF32x2ToF6x2Op / ConvertFP6Type) replacements.
llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
bool hasRelu) {
llvm::Intrinsic::ID
ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
switch (type) {
// GET_F32x2_TO_F6x2_ID is defined above this hunk (its head is not visible
// here); presumably it is the macro whose visible tail chooses between the
// nvvm_ff_to_<type>_rn_relu_satfinite and nvvm_ff_to_<type>_rn_satfinite
// intrinsics based on has_relu -- TODO confirm.
case NVVM::CVTFP6Type::E2M3:
case NVVM::ConvertFP6Type::E2M3:
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
case NVVM::CVTFP6Type::E3M2:
case NVVM::ConvertFP6Type::E3M2:
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
}
// Only reached if `type` matches none of the cases above. This line appears
// to be new in this change (the hunk header grows from 14 to 15 lines).
llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
}

#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
Expand All @@ -1375,20 +1376,20 @@ llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu) {
llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

switch (type) {
case NVVM::CVTFP8Type::E4M3:
case NVVM::ConvertFP8Type::E4M3:
return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
case NVVM::CVTFP8Type::E5M2:
case NVVM::ConvertFP8Type::E5M2:
return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
case NVVM::CVTFP8Type::UE8M0:
case NVVM::ConvertFP8Type::UE8M0:
if (hasRoundingModeRZ)
return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
else if (hasRoundingModeRP)
Expand All @@ -1401,15 +1402,15 @@ llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

// Picks the LLVM intrinsic ID for the f16x2 -> f8x2 conversion from the FP8
// destination type and the relu flag. Only E4M3 and E5M2 are mapped; any
// other enum value falls into the `default` branch and is treated as
// unreachable.
// NOTE(review): this is a rendered diff hunk, so the pre-rename
// (CvtF16x2ToF8x2Op / CVTFP8Type) lines are immediately followed by their
// post-rename (ConvertF16x2ToF8x2Op / ConvertFP8Type) replacements.
llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
bool hasRelu) {
llvm::Intrinsic::ID
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
switch (type) {
// GET_F16x2_TO_F8X2_ID is defined above this hunk (its head is not visible
// here); presumably it is the macro whose visible tail chooses between the
// nvvm_f16x2_to_<type>_rn_relu and nvvm_f16x2_to_<type>_rn intrinsics based
// on has_relu -- TODO confirm.
case NVVM::CVTFP8Type::E4M3:
case NVVM::ConvertFP8Type::E4M3:
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
case NVVM::CVTFP8Type::E5M2:
case NVVM::ConvertFP8Type::E5M2:
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
default:
// The post-rename message now names the renamed op as well as the renamed
// enum, matching the other diagnostics updated by this change.
llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
llvm_unreachable("Invalid ConvertFP8Type for ConvertF16x2ToF8x2Op");
}
}

Expand All @@ -1418,8 +1419,8 @@ llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
: llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat) {
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
switch (rnd) {
case NVVM::FPRoundingMode::RZ:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
// Packed-result case: converts a pair of f32 values to each FP6 format with an
// i16 result and expects a direct call to the matching
// llvm.nvvm.ff.to.<format>x2.rn.satfinite intrinsic.
// NOTE(review): this is a rendered diff hunk, so each pre-rename `nvvm.cvt.*`
// line is immediately followed by its post-rename `nvvm.convert.*` replacement.
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
%res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
llvm.return
}

// Vector-result case: same conversions as the packed test, but with a
// vector<2xi8> result; expects the i16 intrinsic result to be followed
// directly by a bitcast to <2 x i8>.
// NOTE(review): this is a rendered diff hunk, so each pre-rename `nvvm.cvt.*`
// line is immediately followed by its post-rename `nvvm.convert.*` replacement.
// CHECK-LABEL: @convert_f32x2_to_fp6x2_vector
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
%res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
llvm.return
}
Loading