16
16
include "mlir/Interfaces/InferTypeOpInterface.td"
17
17
include "mlir/Interfaces/SideEffectInterfaces.td"
18
18
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19
+ include "mlir/Dialect/X86Vector/X86VectorInterfaces.td"
19
20
20
21
//===----------------------------------------------------------------------===//
21
22
// X86Vector dialect definition
@@ -34,30 +35,12 @@ def X86Vector_Dialect : Dialect {
34
35
// Base class for AVX512 ops of the X86Vector dialect; the op name is the
// given mnemonic prefixed with "avx512.".
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
35
36
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
36
37
37
- // Intrinsic operation used during lowering to LLVM IR.
38
- class AVX512_IntrOp<string mnemonic, int numResults,
39
- list<Trait> traits = [],
40
- string extension = ""> :
41
- LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
42
- !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
43
- [], [], traits, numResults>;
44
-
45
- // Defined by first result overload. May have to be extended for other
46
- // instructions in the future.
47
- class AVX512_IntrOverloadedOp<string mnemonic,
48
- list<Trait> traits = [],
49
- string extension = ""> :
50
- LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
51
- !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
52
- /*list<int> overloadedResults=*/[0],
53
- /*list<int> overloadedOperands=*/[],
54
- traits, /*numResults=*/1>;
55
-
56
38
//----------------------------------------------------------------------------//
57
39
// MaskCompressOp
58
40
//----------------------------------------------------------------------------//
59
41
60
42
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
43
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
61
44
// TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
62
45
// then be removed from assemblyFormat.
63
46
AllTypesMatch<["a", "dst"]>,
@@ -91,28 +74,25 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
91
74
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
92
75
" `:` type($dst) (`,` type($src)^)?";
93
76
let hasVerifier = 1;
94
- }
95
77
96
- def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
97
- Pure,
98
- AllTypesMatch<["a", "src", "res"]>,
99
- TypesMatchWith<"`k` has the same number of bits as elements in `res`",
100
- "res", "k",
101
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
102
- "IntegerType::get($_self.getContext(), 1))">]> {
103
- let arguments = (ins VectorOfLengthAndType<[16, 8],
104
- [F32, I32, F64, I64]>:$a,
105
- VectorOfLengthAndType<[16, 8],
106
- [F32, I32, F64, I64]>:$src,
107
- VectorOfLengthAndType<[16, 8],
108
- [I1]>:$k);
78
+ let extraClassDefinition = [{
79
+ std::string $cppClass::getIntrinsicName() {  // Returns the name of the LLVM intrinsic for this op.
80
+ // Call the baseline overloaded intrinsic.
81
+ // Final overload name mangling is resolved by the created function call.
82
+ return "llvm.x86.avx512.mask.compress";
83
+ }
84
+ }];
85
+ let extraClassDeclaration = [{
86
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);  // Declared only; body not visible here (TODO confirm where it is defined).
87
+ }];
109
88
}
110
89
111
90
//----------------------------------------------------------------------------//
112
91
// MaskRndScaleOp
113
92
//----------------------------------------------------------------------------//
114
93
115
94
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
95
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
116
96
AllTypesMatch<["src", "a", "dst"]>,
117
97
TypesMatchWith<"imm has the same number of bits as elements in dst",
118
98
"dst", "imm",
@@ -142,33 +122,28 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
142
122
let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
143
123
let assemblyFormat =
144
124
"$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
145
- }
146
-
147
- def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
148
- Pure,
149
- AllTypesMatch<["src", "a", "res"]>]> {
150
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
151
- I32:$k,
152
- VectorOfLengthAndType<[16], [F32]>:$a,
153
- I16:$imm,
154
- LLVM_Type:$rounding);
155
- }
156
125
157
- def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
158
- Pure,
159
- AllTypesMatch<["src", "a", "res"]>]> {
160
- let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
161
- I32:$k,
162
- VectorOfLengthAndType<[8], [F64]>:$a,
163
- I8:$imm,
164
- LLVM_Type:$rounding);
126
+ let extraClassDefinition = [{
127
+ std::string $cppClass::getIntrinsicName() {  // Returns the intrinsic name: base + element suffix + bit width.
128
+ std::string intr = "llvm.x86.avx512.mask.rndscale";
129
+ VectorType vecType = getSrc().getType();  // Name is derived from the type of src.
130
+ Type elemType = vecType.getElementType();
131
+ intr += ".";
132
+ intr += elemType.isF32() ? "ps" : "pd";  // ps for f32 elements, pd for anything else.
133
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
134
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;  // Size of dim 0 times element bit width.
135
+ intr += "." + std::to_string(opBitWidth);
136
+ return intr;
137
+ }
138
+ }];
165
139
}
166
140
167
141
//----------------------------------------------------------------------------//
168
142
// MaskScaleFOp
169
143
//----------------------------------------------------------------------------//
170
144
171
145
def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
146
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
172
147
AllTypesMatch<["src", "a", "b", "dst"]>,
173
148
TypesMatchWith<"k has the same number of bits as elements in dst",
174
149
"dst", "k",
@@ -199,33 +174,28 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
199
174
// Fully specified by traits.
200
175
let assemblyFormat =
201
176
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
202
- }
203
-
204
- def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
205
- Pure,
206
- AllTypesMatch<["src", "a", "b", "res"]>]> {
207
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
208
- VectorOfLengthAndType<[16], [F32]>:$a,
209
- VectorOfLengthAndType<[16], [F32]>:$b,
210
- I16:$k,
211
- LLVM_Type:$rounding);
212
- }
213
177
214
- def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
215
- Pure,
216
- AllTypesMatch<["src", "a", "b", "res"]>]> {
217
- let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
218
- VectorOfLengthAndType<[8], [F64]>:$a,
219
- VectorOfLengthAndType<[8], [F64]>:$b,
220
- I8:$k,
221
- LLVM_Type:$rounding);
178
+ let extraClassDefinition = [{
179
+ std::string $cppClass::getIntrinsicName() {  // Returns the intrinsic name: base + element suffix + bit width.
180
+ std::string intr = "llvm.x86.avx512.mask.scalef";
181
+ VectorType vecType = getSrc().getType();  // Name is derived from the type of src.
182
+ Type elemType = vecType.getElementType();
183
+ intr += ".";
184
+ intr += elemType.isF32() ? "ps" : "pd";  // ps for f32 elements, pd for anything else.
185
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
186
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;  // Size of dim 0 times element bit width.
187
+ intr += "." + std::to_string(opBitWidth);
188
+ return intr;
189
+ }
190
+ }];
222
191
}
223
192
224
193
//----------------------------------------------------------------------------//
225
194
// Vp2IntersectOp
226
195
//----------------------------------------------------------------------------//
227
196
228
197
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
198
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
229
199
AllTypesMatch<["a", "b"]>,
230
200
TypesMatchWith<"k1 has the same number of bits as elements in a",
231
201
"a", "k1",
@@ -260,25 +230,28 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
260
230
);
261
231
let assemblyFormat =
262
232
"$a `,` $b attr-dict `:` type($a)";
263
- }
264
-
265
- def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
266
- Pure]> {
267
- let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
268
- VectorOfLengthAndType<[16], [I32]>:$b);
269
- }
270
233
271
- def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
272
- Pure]> {
273
- let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
274
- VectorOfLengthAndType<[8], [I64]>:$b);
234
+ let extraClassDefinition = [{
235
+ std::string $cppClass::getIntrinsicName() {  // Returns the intrinsic name: base + element suffix + bit width.
236
+ std::string intr = "llvm.x86.avx512.vp2intersect";
237
+ VectorType vecType = getA().getType();  // Name is derived from the type of a.
238
+ Type elemType = vecType.getElementType();
239
+ intr += ".";
240
+ intr += elemType.isInteger(32) ? "d" : "q";  // d for 32-bit integer elements, q for anything else.
241
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
242
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;  // Size of dim 0 times element bit width.
243
+ intr += "." + std::to_string(opBitWidth);
244
+ return intr;
245
+ }
246
+ }];
275
247
}
276
248
277
249
//----------------------------------------------------------------------------//
278
250
// Dot BF16
279
251
//----------------------------------------------------------------------------//
280
252
281
253
def DotBF16Op : AVX512_Op<"dot", [Pure,
254
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
282
255
AllTypesMatch<["a", "b"]>,
283
256
AllTypesMatch<["src", "dst"]>,
284
257
TypesMatchWith<"`a` has twice an many elements as `src`",
@@ -299,7 +272,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
299
272
300
273
Example:
301
274
```mlir
302
- %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
275
+ %dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
303
276
```
304
277
}];
305
278
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -309,43 +282,25 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
309
282
let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
310
283
let assemblyFormat =
311
284
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
312
- }
313
285
314
- def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
315
- AllTypesMatch<["a", "b"]>,
316
- AllTypesMatch<["src", "res"]>],
317
- /*extension=*/"bf16"> {
318
- let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
319
- VectorOfLengthAndType<[8], [BF16]>:$a,
320
- VectorOfLengthAndType<[8], [BF16]>:$b);
321
- let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
322
- }
323
-
324
- def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
325
- AllTypesMatch<["a", "b"]>,
326
- AllTypesMatch<["src", "res"]>],
327
- /*extension=*/"bf16"> {
328
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
329
- VectorOfLengthAndType<[16], [BF16]>:$a,
330
- VectorOfLengthAndType<[16], [BF16]>:$b);
331
- let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
332
- }
333
-
334
- def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
335
- AllTypesMatch<["a", "b"]>,
336
- AllTypesMatch<["src", "res"]>],
337
- /*extension=*/"bf16"> {
338
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
339
- VectorOfLengthAndType<[32], [BF16]>:$a,
340
- VectorOfLengthAndType<[32], [BF16]>:$b);
341
- let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
286
+ let extraClassDefinition = [{
287
+ std::string $cppClass::getIntrinsicName() {  // Returns the intrinsic name: base + bit width suffix.
288
+ std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
289
+ VectorType vecType = getSrc().getType();  // Width is computed from src, not from a or b.
290
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
291
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;  // Size of dim 0 times element bit width.
292
+ intr += "." + std::to_string(opBitWidth);
293
+ return intr;
294
+ }
295
+ }];
342
296
}
343
297
344
298
//----------------------------------------------------------------------------//
345
299
// Convert packed F32 to packed BF16
346
300
//----------------------------------------------------------------------------//
347
301
348
302
def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
303
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
349
304
AllElementCountsMatch<["a", "dst"]>]> {
350
305
let summary = "Convert packed F32 to packed BF16 Data.";
351
306
let description = [{
@@ -367,18 +322,17 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
367
322
let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
368
323
let assemblyFormat =
369
324
"$a attr-dict `:` type($a) `->` type($dst)";
370
- }
371
-
372
- def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
373
- /*extension=*/"bf16"> {
374
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
375
- let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
376
- }
377
325
378
- def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
379
- /*extension=*/"bf16"> {
380
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
381
- let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
326
+ let extraClassDefinition = [{
327
+ std::string $cppClass::getIntrinsicName() {  // Returns the intrinsic name: base + bit width suffix.
328
+ std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
329
+ VectorType vecType = getA().getType();  // Width is computed from the input a, not from dst.
330
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
331
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;  // Size of dim 0 times element bit width.
332
+ intr += "." + std::to_string(opBitWidth);
333
+ return intr;
334
+ }
335
+ }];
382
336
}
383
337
384
338
//===----------------------------------------------------------------------===//
@@ -395,33 +349,32 @@ class AVX_Op<string mnemonic, list<Trait> traits = []> :
395
349
// Base class for AVX ops of the X86Vector dialect whose op name is the given
// mnemonic prefixed with "avx.intr.".
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
396
350
Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
397
351
398
- // Intrinsic operation used during lowering to LLVM IR.
399
- class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
400
- LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
401
- "x86_avx_" # !subst(".", "_", mnemonic),
402
- [], [], traits, numResults>;
403
-
404
352
//----------------------------------------------------------------------------//
405
353
// AVX Rsqrt
406
354
//----------------------------------------------------------------------------//
407
355
408
- def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
356
+ def RsqrtOp : AVX_Op<"rsqrt", [Pure,
357
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
358
+ SameOperandsAndResultType]> {
409
359
let summary = "Rsqrt";
410
360
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
411
361
let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
412
362
let assemblyFormat = "$a attr-dict `:` type($a)";
413
- }
414
363
415
- def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure,
416
- SameOperandsAndResultType]> {
417
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
364
+ let extraClassDefinition = [{
365
+ std::string $cppClass::getIntrinsicName() {  // Returns the name of the LLVM intrinsic for this op.
366
+ return "llvm.x86.avx.rsqrt.ps.256";  // Fixed name; nothing is derived from the operand type.
367
+ }
368
+ }];
418
369
}
419
370
420
371
//----------------------------------------------------------------------------//
421
372
// AVX Dot
422
373
//----------------------------------------------------------------------------//
423
374
424
- def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
375
+ def DotOp : AVX_LowOp<"dot", [Pure,
376
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
377
+ SameOperandsAndResultType]> {
425
378
let summary = "Dot";
426
379
let description = [{
427
380
Computes the 4-way dot products of the lower and higher parts of the source
@@ -443,13 +396,16 @@ def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
443
396
VectorOfLengthAndType<[8], [F32]>:$b);
444
397
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
445
398
let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
446
- }
447
399
448
- def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure,
449
- AllTypesMatch<["a", "b", "res"]>]> {
450
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
451
- VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
452
- let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
400
+ let extraClassDefinition = [{
401
+ std::string $cppClass::getIntrinsicName() {  // Returns the name of the LLVM intrinsic for this op.
402
+ // Only one variant is supported right now - no extra mangling.
403
+ return "llvm.x86.avx.dp.ps.256";
404
+ }
405
+ }];
406
+ let extraClassDeclaration = [{
407
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);  // Declared only; body not visible here (TODO confirm where it is defined).
408
+ }];
453
409
}
454
410
455
411
#endif // X86VECTOR_OPS
0 commit comments