Skip to content

Commit 0c2a6f2

Browse files
authored
[mlir][x86vector] Simplify intrinsic generation (llvm#133692)
Replaces separate x86vector named intrinsic operations with direct calls to LLVM intrinsic functions. This rework reduces the number of named ops, leaving only high-level MLIR equivalents of whole intrinsic classes, e.g., variants of AVX512 dot on BF16 inputs. Dialect conversion applies LLVM intrinsic name mangling, further simplifying the lowering logic. The separate conversion step translating x86vector intrinsics into LLVM IR is also eliminated. Instead, this step is now performed by the existing LLVM dialect infrastructure. RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581
1 parent e98d138 commit 0c2a6f2

File tree

16 files changed

+367
-566
lines changed

16 files changed

+367
-566
lines changed
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_dialect(X86Vector x86vector)
22
add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
33

4-
set(LLVM_TARGET_DEFINITIONS X86Vector.td)
5-
mlir_tablegen(X86VectorConversions.inc -gen-llvmir-conversions)
6-
add_public_tablegen_target(MLIRX86VectorConversionsIncGen)
4+
add_mlir_interface(X86VectorInterfaces)
5+
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 97 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
include "mlir/Interfaces/InferTypeOpInterface.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
1818
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19+
include "mlir/Dialect/X86Vector/X86VectorInterfaces.td"
1920

2021
//===----------------------------------------------------------------------===//
2122
// X86Vector dialect definition
@@ -34,30 +35,12 @@ def X86Vector_Dialect : Dialect {
3435
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
3536
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
3637

37-
// Intrinsic operation used during lowering to LLVM IR.
38-
class AVX512_IntrOp<string mnemonic, int numResults,
39-
list<Trait> traits = [],
40-
string extension = ""> :
41-
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
42-
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
43-
[], [], traits, numResults>;
44-
45-
// Defined by first result overload. May have to be extended for other
46-
// instructions in the future.
47-
class AVX512_IntrOverloadedOp<string mnemonic,
48-
list<Trait> traits = [],
49-
string extension = ""> :
50-
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
51-
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
52-
/*list<int> overloadedResults=*/[0],
53-
/*list<int> overloadedOperands=*/[],
54-
traits, /*numResults=*/1>;
55-
5638
//----------------------------------------------------------------------------//
5739
// MaskCompressOp
5840
//----------------------------------------------------------------------------//
5941

6042
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
43+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
6144
// TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
6245
// then be removed from assemblyFormat.
6346
AllTypesMatch<["a", "dst"]>,
@@ -91,28 +74,25 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
9174
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
9275
" `:` type($dst) (`,` type($src)^)?";
9376
let hasVerifier = 1;
94-
}
9577

96-
def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
97-
Pure,
98-
AllTypesMatch<["a", "src", "res"]>,
99-
TypesMatchWith<"`k` has the same number of bits as elements in `res`",
100-
"res", "k",
101-
"VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
102-
"IntegerType::get($_self.getContext(), 1))">]> {
103-
let arguments = (ins VectorOfLengthAndType<[16, 8],
104-
[F32, I32, F64, I64]>:$a,
105-
VectorOfLengthAndType<[16, 8],
106-
[F32, I32, F64, I64]>:$src,
107-
VectorOfLengthAndType<[16, 8],
108-
[I1]>:$k);
78+
let extraClassDefinition = [{
79+
std::string $cppClass::getIntrinsicName() {
80+
// Call the baseline overloaded intrinsic.
81+
// Final overload name mangling is resolved by the created function call.
82+
return "llvm.x86.avx512.mask.compress";
83+
}
84+
}];
85+
let extraClassDeclaration = [{
86+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
87+
}];
10988
}
11089

11190
//----------------------------------------------------------------------------//
11291
// MaskRndScaleOp
11392
//----------------------------------------------------------------------------//
11493

11594
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
95+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
11696
AllTypesMatch<["src", "a", "dst"]>,
11797
TypesMatchWith<"imm has the same number of bits as elements in dst",
11898
"dst", "imm",
@@ -142,33 +122,28 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
142122
let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
143123
let assemblyFormat =
144124
"$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
145-
}
146-
147-
def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
148-
Pure,
149-
AllTypesMatch<["src", "a", "res"]>]> {
150-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
151-
I32:$k,
152-
VectorOfLengthAndType<[16], [F32]>:$a,
153-
I16:$imm,
154-
LLVM_Type:$rounding);
155-
}
156125

157-
def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
158-
Pure,
159-
AllTypesMatch<["src", "a", "res"]>]> {
160-
let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
161-
I32:$k,
162-
VectorOfLengthAndType<[8], [F64]>:$a,
163-
I8:$imm,
164-
LLVM_Type:$rounding);
126+
let extraClassDefinition = [{
127+
std::string $cppClass::getIntrinsicName() {
128+
std::string intr = "llvm.x86.avx512.mask.rndscale";
129+
VectorType vecType = getSrc().getType();
130+
Type elemType = vecType.getElementType();
131+
intr += ".";
132+
intr += elemType.isF32() ? "ps" : "pd";
133+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
134+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
135+
intr += "." + std::to_string(opBitWidth);
136+
return intr;
137+
}
138+
}];
165139
}
166140

167141
//----------------------------------------------------------------------------//
168142
// MaskScaleFOp
169143
//----------------------------------------------------------------------------//
170144

171145
def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
146+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
172147
AllTypesMatch<["src", "a", "b", "dst"]>,
173148
TypesMatchWith<"k has the same number of bits as elements in dst",
174149
"dst", "k",
@@ -199,33 +174,28 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
199174
// Fully specified by traits.
200175
let assemblyFormat =
201176
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
202-
}
203-
204-
def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
205-
Pure,
206-
AllTypesMatch<["src", "a", "b", "res"]>]> {
207-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
208-
VectorOfLengthAndType<[16], [F32]>:$a,
209-
VectorOfLengthAndType<[16], [F32]>:$b,
210-
I16:$k,
211-
LLVM_Type:$rounding);
212-
}
213177

214-
def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
215-
Pure,
216-
AllTypesMatch<["src", "a", "b", "res"]>]> {
217-
let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
218-
VectorOfLengthAndType<[8], [F64]>:$a,
219-
VectorOfLengthAndType<[8], [F64]>:$b,
220-
I8:$k,
221-
LLVM_Type:$rounding);
178+
let extraClassDefinition = [{
179+
std::string $cppClass::getIntrinsicName() {
180+
std::string intr = "llvm.x86.avx512.mask.scalef";
181+
VectorType vecType = getSrc().getType();
182+
Type elemType = vecType.getElementType();
183+
intr += ".";
184+
intr += elemType.isF32() ? "ps" : "pd";
185+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
186+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
187+
intr += "." + std::to_string(opBitWidth);
188+
return intr;
189+
}
190+
}];
222191
}
223192

224193
//----------------------------------------------------------------------------//
225194
// Vp2IntersectOp
226195
//----------------------------------------------------------------------------//
227196

228197
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
198+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
229199
AllTypesMatch<["a", "b"]>,
230200
TypesMatchWith<"k1 has the same number of bits as elements in a",
231201
"a", "k1",
@@ -260,25 +230,28 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
260230
);
261231
let assemblyFormat =
262232
"$a `,` $b attr-dict `:` type($a)";
263-
}
264-
265-
def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
266-
Pure]> {
267-
let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
268-
VectorOfLengthAndType<[16], [I32]>:$b);
269-
}
270233

271-
def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
272-
Pure]> {
273-
let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
274-
VectorOfLengthAndType<[8], [I64]>:$b);
234+
let extraClassDefinition = [{
235+
std::string $cppClass::getIntrinsicName() {
236+
std::string intr = "llvm.x86.avx512.vp2intersect";
237+
VectorType vecType = getA().getType();
238+
Type elemType = vecType.getElementType();
239+
intr += ".";
240+
intr += elemType.isInteger(32) ? "d" : "q";
241+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
242+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
243+
intr += "." + std::to_string(opBitWidth);
244+
return intr;
245+
}
246+
}];
275247
}
276248

277249
//----------------------------------------------------------------------------//
278250
// Dot BF16
279251
//----------------------------------------------------------------------------//
280252

281253
def DotBF16Op : AVX512_Op<"dot", [Pure,
254+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
282255
AllTypesMatch<["a", "b"]>,
283256
AllTypesMatch<["src", "dst"]>,
284257
TypesMatchWith<"`a` has twice an many elements as `src`",
@@ -299,7 +272,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
299272

300273
Example:
301274
```mlir
302-
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
275+
%dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
303276
```
304277
}];
305278
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -309,43 +282,25 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
309282
let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
310283
let assemblyFormat =
311284
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
312-
}
313285

314-
def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
315-
AllTypesMatch<["a", "b"]>,
316-
AllTypesMatch<["src", "res"]>],
317-
/*extension=*/"bf16"> {
318-
let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
319-
VectorOfLengthAndType<[8], [BF16]>:$a,
320-
VectorOfLengthAndType<[8], [BF16]>:$b);
321-
let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
322-
}
323-
324-
def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
325-
AllTypesMatch<["a", "b"]>,
326-
AllTypesMatch<["src", "res"]>],
327-
/*extension=*/"bf16"> {
328-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
329-
VectorOfLengthAndType<[16], [BF16]>:$a,
330-
VectorOfLengthAndType<[16], [BF16]>:$b);
331-
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
332-
}
333-
334-
def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
335-
AllTypesMatch<["a", "b"]>,
336-
AllTypesMatch<["src", "res"]>],
337-
/*extension=*/"bf16"> {
338-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
339-
VectorOfLengthAndType<[32], [BF16]>:$a,
340-
VectorOfLengthAndType<[32], [BF16]>:$b);
341-
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
286+
let extraClassDefinition = [{
287+
std::string $cppClass::getIntrinsicName() {
288+
std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
289+
VectorType vecType = getSrc().getType();
290+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
291+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
292+
intr += "." + std::to_string(opBitWidth);
293+
return intr;
294+
}
295+
}];
342296
}
343297

344298
//----------------------------------------------------------------------------//
345299
// Convert packed F32 to packed BF16
346300
//----------------------------------------------------------------------------//
347301

348302
def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
303+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
349304
AllElementCountsMatch<["a", "dst"]>]> {
350305
let summary = "Convert packed F32 to packed BF16 Data.";
351306
let description = [{
@@ -367,18 +322,17 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
367322
let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
368323
let assemblyFormat =
369324
"$a attr-dict `:` type($a) `->` type($dst)";
370-
}
371-
372-
def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
373-
/*extension=*/"bf16"> {
374-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
375-
let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
376-
}
377325

378-
def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
379-
/*extension=*/"bf16"> {
380-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
381-
let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
326+
let extraClassDefinition = [{
327+
std::string $cppClass::getIntrinsicName() {
328+
std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
329+
VectorType vecType = getA().getType();
330+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
331+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
332+
intr += "." + std::to_string(opBitWidth);
333+
return intr;
334+
}
335+
}];
382336
}
383337

384338
//===----------------------------------------------------------------------===//
@@ -395,33 +349,32 @@ class AVX_Op<string mnemonic, list<Trait> traits = []> :
395349
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
396350
Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
397351

398-
// Intrinsic operation used during lowering to LLVM IR.
399-
class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
400-
LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
401-
"x86_avx_" # !subst(".", "_", mnemonic),
402-
[], [], traits, numResults>;
403-
404352
//----------------------------------------------------------------------------//
405353
// AVX Rsqrt
406354
//----------------------------------------------------------------------------//
407355

408-
def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
356+
def RsqrtOp : AVX_Op<"rsqrt", [Pure,
357+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
358+
SameOperandsAndResultType]> {
409359
let summary = "Rsqrt";
410360
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
411361
let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
412362
let assemblyFormat = "$a attr-dict `:` type($a)";
413-
}
414363

415-
def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure,
416-
SameOperandsAndResultType]> {
417-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
364+
let extraClassDefinition = [{
365+
std::string $cppClass::getIntrinsicName() {
366+
return "llvm.x86.avx.rsqrt.ps.256";
367+
}
368+
}];
418369
}
419370

420371
//----------------------------------------------------------------------------//
421372
// AVX Dot
422373
//----------------------------------------------------------------------------//
423374

424-
def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
375+
def DotOp : AVX_LowOp<"dot", [Pure,
376+
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
377+
SameOperandsAndResultType]> {
425378
let summary = "Dot";
426379
let description = [{
427380
Computes the 4-way dot products of the lower and higher parts of the source
@@ -443,13 +396,16 @@ def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
443396
VectorOfLengthAndType<[8], [F32]>:$b);
444397
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
445398
let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
446-
}
447399

448-
def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure,
449-
AllTypesMatch<["a", "b", "res"]>]> {
450-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
451-
VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
452-
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
400+
let extraClassDefinition = [{
401+
std::string $cppClass::getIntrinsicName() {
402+
// Only one variant is supported right now - no extra mangling.
403+
return "llvm.x86.avx.dp.ps.256";
404+
}
405+
}];
406+
let extraClassDeclaration = [{
407+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
408+
}];
453409
}
454410

455411
#endif // X86VECTOR_OPS

mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@
1818
#include "mlir/IR/Dialect.h"
1919
#include "mlir/IR/OpDefinition.h"
2020
#include "mlir/IR/OpImplementation.h"
21+
#include "mlir/IR/PatternMatch.h"
2122
#include "mlir/Interfaces/InferTypeOpInterface.h"
2223
#include "mlir/Interfaces/SideEffectInterfaces.h"
2324

25+
/// Include the generated interface declarations.
26+
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc"
27+
2428
#include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc"
2529

2630
#define GET_OP_CLASSES

0 commit comments

Comments
 (0)