[mlir][spirv] Add bfloat16 support #141458
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir

Author: Darren Wihandi (fairywreath)

Changes: Adds bf16 support to SPIR-V by using the `SPV_KHR_bfloat16` extension. Remaining TODO:
Full diff: https://github.com/llvm/llvm-project/pull/141458.diff (11 files affected)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..daa1b2b328115 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];
let arguments = (ins
- SPIRV_VectorOf<SPIRV_Float>:$vector1,
- SPIRV_VectorOf<SPIRV_Float>:$vector2
+ SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector1,
+ SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector2
);
let results = (outs
- SPIRV_Float:$result
+ SPIRV_FloatOrBFloat16:$result
);
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8fd533db83d9a..5d4469954e5b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
+def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
- SPV_KHR_cooperative_matrix,
+ SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
Extension<[SPV_NV_stereo_view_rendering]>
];
}
+def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
+def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
+def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
+ list<Availability> availability = [
+ Extension<[SPV_KHR_bfloat16]>
+ ];
+}
def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
- SPIRV_C_CacheControlsINTEL
+ SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
+ SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;
+def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_BFloat16TypeKHR]>
+ ];
+}
+def SPIRV_FPEncodingAttr :
+ SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
+ SPIRV_FPE_BFloat16KHR
+ ]>;
+
def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
@@ -4163,8 +4192,9 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
- [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
+ [SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4194,9 +4224,9 @@ def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
- SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
+ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index b05ee0251df5b..29571cf138ebf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
// -----
-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to signed integer, with
round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
// -----
-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to unsigned integer, with
round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
// -----
def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
- SPIRV_Float,
+ SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[SignedOp]> {
let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
// -----
def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
- SPIRV_Float,
+ SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[UnsignedOp]> {
let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
// -----
def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
- SPIRV_Float,
- SPIRV_Float,
+ SPIRV_FloatOrBFloat16,
+ SPIRV_FloatOrBFloat16,
[UsableInSpecConstantOp]> {
let summary = [{
Convert value numerically from one floating-point width to another
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 0cf5f0823be63..a21acef1c4b43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
// Check other allowed types
if (auto t = llvm::dyn_cast<FloatType>(type)) {
- if (type.isBF16()) {
- parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
- return Type();
- }
+ // TODO: All float types are allowed for now, but this should be fixed.
} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..5da3164ad4d14 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
}
bool ScalarType::isValid(FloatType type) {
- return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
+ return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
}
bool ScalarType::isValid(IntegerType type) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 15e06616f4492..b43f22db55a2e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
if (auto floatType = dyn_cast<FloatType>(type)) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
+ if (floatType.isBF16()) {
+ operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
+ }
return success();
}
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2..2e34c9ff54012 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -206,18 +206,6 @@ func.func @float64(%arg0: f64) { return }
// -----
-// Check that bf16 is not supported.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-NOT: spirv.func @bf16_type
-func.func @bf16_type(%arg0: bf16) { return }
-
-} // end module
-
-// -----
-
//===----------------------------------------------------------------------===//
// Complex types
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 2d0c86e08de5a..301a5bab9ab1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -265,6 +265,13 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// -----
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ return %0 : bf16
+}
+
+// -----
+
// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -283,7 +290,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
- // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or bfloat16 type values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 34d0109e6bb44..4480a1f3720f2 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
// -----
+func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
+ // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertFToU
//===----------------------------------------------------------------------===//
@@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
// -----
+func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
+ // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+ %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertSToF
//===----------------------------------------------------------------------===//
@@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
// -----
+func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+ // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
+ %0 = spirv.ConvertSToF %arg0 : i32 to bf16
+ spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertUToF
//===----------------------------------------------------------------------===//
@@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
// -----
+func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+ // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
+ %0 = spirv.ConvertUToF %arg0 : i32 to bf16
+ spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.FConvert
//===----------------------------------------------------------------------===//
@@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
// -----
+func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+ %0 = spirv.FConvert %arg0 : bf16 to f32
+ spirv.ReturnValue %0 : f32
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
+ %0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
+ spirv.ReturnValue %0 : vector<3xbf16>
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+ %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+ spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..8929e63639c97 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -31,6 +31,15 @@ spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuf
spirv.Return
}
+// CHECK-LABEL: @cooperative_matrix_load_bf16
+spirv.func @cooperative_matrix_load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
+ // CHECK-SAME: : !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+ !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
+ spirv.Return
+}
+
// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
@@ -225,6 +234,26 @@ spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgro
spirv.Return
}
+spirv.func @cooperative_matrix_muladd_bf16_bf16(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+spirv.func @cooperative_matrix_muladd_bf16_f32(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>) "None" {
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
+ !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
+ !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index b63a08d96e6af..a81fe72a8362e 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -57,11 +57,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()
// -----
-// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
-func.func private @bf16_type(!spirv.array<4xbf16>) -> ()
-
-// -----
-
// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
func.func private @i256_type(!spirv.array<4xi256>) -> ()
|
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
-                   [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
+                   [SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
We can argue that bf16 should be part of `SPIRV_Float`. The problem here is that bf16 usage in SPIR-V is very limited, while `SPIRV_Float` (i.e. regular floats) is used widely in the codebase for other ops (e.g. texture sampling and regular arithmetic instructions). I chose to leave `SPIRV_Float` as-is to minimize the amount of changes (and to not introduce something like `SPIRV_ArithmeticFloat`). Please let me know if you think there is a cleaner solution to this.
I agree with this; it might be better to leave them alone. What about something like this:
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : AnyTypeOf<[F16, F32, F64]>;
def SPIRV_Float16or32 : AnyTypeOf<[F16, F32]>;
// Use this type for all kinds of floats.
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float]>;
.....
def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
......
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
.........
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
]>;
-      parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
-      return Type();
-    }
+    // TODO: All float types are allowed for now, but this should be fixed.
I will address this in a separate PR.
Could you please elaborate what needs to be fixed here?
The current behavior does not error out on bitwidths that are invalid for SPIR-V (e.g. F80, F128) or on non-standard formats (e.g. E3M2). Do you think it's better to address this here or in a separate PR?
In my opinion it's okay to address it later. In fact I think it's preferable. Currently the code doesn't do any checks anyway, other than checking for bf16, so adding a proper check would be out of scope of this PR.
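To make the gap concrete, here is a hypothetical negative test of the kind a follow-up would add, mirroring the existing i256 test; the function name and placement are illustrative only. With the bf16 rejection removed, the parser currently accepts any FloatType, so this produces no diagnostic today even though SPIR-V has no 80-bit float encoding:

// Hypothetical: currently parses without error; a proper width check
// would be expected to reject it, like the i256 case below it.
func.func private @f80_type(!spirv.array<4xf80>) -> ()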
Thanks for submitting this PR! Adding bf16 support was somewhere on my TODO list, so I'm glad to see this patch. I've left a few comments - sorry for the nitpicks! :)
@@ -37,20 +37,43 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }

-class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
+class SPIRV_ArithmeticWithCoopMatrixBinaryOp<string mnemonic,
I'm not sure whether this change should be a part of this PR. I see how it's related, but I think it may make more sense to land the bf16 extension first and then have a separate PR for changes to the ops. This would also mean you could address the TODO above as a part of the second PR. So, my suggestion is to have 2 PRs:
1. Adding `SPV_KHR_bfloat16` support
2. Enabling bf16 with coop matrices

But we should let the code owners make the decision on that.
I agree, I have limited the scope of this PR to adding the type and basic cast ops.
Thanks for contributing this.
Do we emit errors when using unsupported arithmetic ops etc. with bf16? Could we have tests that cover this?
Thank you so much for adding this @fairywreath. I have a few comments.
What is your opinion on creating a TypeAlias for the SPIR-V BF16 type (I put a potential example in the comment)? It comes with some nice properties; for example, verification can be easier to handle. But your current approach is also fine, just a different perspective.
P.S. I had a downstream solution but never got around to upstreaming it. You can check it out if it helps: https://github.com/intel/mlir-extensions/blob/main/build_tools/patches/0009-SPIR-V-Enable-native-bf16-support-in-SPIR-V-dialect.patch
Thanks for the review! I agree with your suggestions and have changed the definitions accordingly. I have also integrated changes from your downstream solution, specifically the serializer/deserializer changes. The changes from your downstream solution that I did not integrate are the CL ops and the arithmetic ops (which Vulkan does not support for bf16).
Yes, originally there was only one FMul test. I have added more tests to ensure errors are emitted for arithmetic, subgroup ops, and atomics.
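A sketch of what one of those negative tests might look like (hypothetical; the exact diagnostic text depends on the final type-constraint summaries):

func.func @fmul_bf16(%arg0 : bf16, %arg1 : bf16) -> bf16 {
  // Hypothetical diagnostic wording; FMul keeps the SPIRV_Float constraint,
  // so a bf16 operand should be rejected by the verifier.
  // expected-error @+1 {{operand #0 must be 16/32/64-bit float}}
  %0 = spirv.FMul %arg0, %arg1 : bf16
  return %0 : bf16
}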
Thank you @fairywreath for addressing the comments.
LGTM.
Please wait for approval from the others who reviewed it.
LGTM, thank you for addressing all the comments.
Adds bf16 support to SPIR-V by using the `SPV_KHR_bfloat16` extension. Only a few operations are supported, including loading from and storing to memory, conversion to/from other types, cooperative matrix operations (including coop matrix arithmetic ops), and dot product support.

This PR adds the type definition and implements the basic cast operations. Arithmetic/coop matrix ops will be added in a separate PR.
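For illustration, the kind of IR this enables once the type is accepted, taken from the cast and Dot tests added in this PR:

// bf16 casts now verify and serialize (FConvert carries the width change).
func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
  %0 = spirv.FConvert %arg0 : bf16 to f32
  spirv.ReturnValue %0 : f32
}

// Dot product on bf16 vectors, gated by the BFloat16DotProductKHR capability.
func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
  %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
  return %0 : bf16
}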