Skip to content

Commit 25bb616

Browse files
author
Tobias Gysi
committed
[mlir][linalg][python] Add attribute support to the YAML codegen.
Extend the yaml code generation to support the index attributes that https://reviews.llvm.org/D104711 added to the OpDSL. Differential Revision: https://reviews.llvm.org/D104712
1 parent adace79 commit 25bb616

File tree

6 files changed

+390
-100
lines changed

6 files changed

+390
-100
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 112 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ structured_op: !LinalgStructuredOpConfig
1313
args:
1414
- !LinalgOperandDefConfig
1515
name: A
16-
usage: input
17-
shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
16+
usage: InputOperand
1817
type_var: T1
18+
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
1919
- !LinalgOperandDefConfig
2020
name: B
21-
usage: input
22-
shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
21+
usage: InputOperand
2322
type_var: T2
23+
shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
2424
- !LinalgOperandDefConfig
2525
name: C
26-
usage: output
27-
shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
26+
usage: OutputOperand
2827
type_var: U
28+
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
2929
indexing_maps: !LinalgIndexingMapsConfig
3030
static_indexing_maps:
3131
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -75,19 +75,19 @@ structured_op: !LinalgStructuredOpConfig
7575
args:
7676
- !LinalgOperandDefConfig
7777
name: A
78-
usage: input
79-
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
78+
usage: InputOperand
8079
type_var: T1
80+
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
8181
- !LinalgOperandDefConfig
8282
name: B
83-
usage: input
84-
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
83+
usage: InputOperand
8584
type_var: T2
85+
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
8686
- !LinalgOperandDefConfig
8787
name: C
88-
usage: output
89-
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
88+
usage: OutputOperand
9089
type_var: U
90+
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
9191
indexing_maps: !LinalgIndexingMapsConfig
9292
static_indexing_maps:
9393
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
@@ -138,19 +138,19 @@ structured_op: !LinalgStructuredOpConfig
138138
args:
139139
- !LinalgOperandDefConfig
140140
name: A
141-
usage: input
142-
shape: affine_map<()[s0, s1] -> (s0, s1)>
141+
usage: InputOperand
143142
type_var: T1
143+
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
144144
- !LinalgOperandDefConfig
145145
name: y
146-
usage: input
147-
shape: affine_map<()[s0, s1] -> (s1)>
146+
usage: InputOperand
148147
type_var: T2
148+
shape_map: affine_map<()[s0, s1] -> (s1)>
149149
- !LinalgOperandDefConfig
150150
name: x
151-
usage: output
152-
shape: affine_map<()[s0, s1] -> (s0)>
151+
usage: OutputOperand
153152
type_var: U
153+
shape_map: affine_map<()[s0, s1] -> (s0)>
154154
indexing_maps: !LinalgIndexingMapsConfig
155155
static_indexing_maps:
156156
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -199,19 +199,19 @@ structured_op: !LinalgStructuredOpConfig
199199
args:
200200
- !LinalgOperandDefConfig
201201
name: y
202-
usage: input
203-
shape: affine_map<()[s0, s1] -> (s1)>
202+
usage: InputOperand
204203
type_var: T1
204+
shape_map: affine_map<()[s0, s1] -> (s1)>
205205
- !LinalgOperandDefConfig
206206
name: A
207-
usage: input
208-
shape: affine_map<()[s0, s1] -> (s1, s0)>
207+
usage: InputOperand
209208
type_var: T2
209+
shape_map: affine_map<()[s0, s1] -> (s1, s0)>
210210
- !LinalgOperandDefConfig
211211
name: x
212-
usage: output
213-
shape: affine_map<()[s0, s1] -> (s0)>
212+
usage: OutputOperand
214213
type_var: U
214+
shape_map: affine_map<()[s0, s1] -> (s0)>
215215
indexing_maps: !LinalgIndexingMapsConfig
216216
static_indexing_maps:
217217
- affine_map<(d0, d1)[s0, s1] -> (d1)>
@@ -260,19 +260,19 @@ structured_op: !LinalgStructuredOpConfig
260260
args:
261261
- !LinalgOperandDefConfig
262262
name: A
263-
usage: input
264-
shape: affine_map<()[s0] -> (s0)>
263+
usage: InputOperand
265264
type_var: T1
265+
shape_map: affine_map<()[s0] -> (s0)>
266266
- !LinalgOperandDefConfig
267267
name: B
268-
usage: input
269-
shape: affine_map<()[s0] -> (s0)>
268+
usage: InputOperand
270269
type_var: T2
270+
shape_map: affine_map<()[s0] -> (s0)>
271271
- !LinalgOperandDefConfig
272272
name: C
273-
usage: output
274-
shape: affine_map<()[s0] -> ()>
273+
usage: OutputOperand
275274
type_var: U
275+
shape_map: affine_map<()[s0] -> ()>
276276
indexing_maps: !LinalgIndexingMapsConfig
277277
static_indexing_maps:
278278
- affine_map<(d0)[s0] -> (d0)>
@@ -306,6 +306,83 @@ structured_op: !LinalgStructuredOpConfig
306306
- !ScalarExpression
307307
scalar_arg: B
308308
--- !LinalgOpConfig
309+
metadata: !LinalgOpMetadata
310+
name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
311+
cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp
312+
doc: A depth-wise 2-D convolution operation.
313+
structured_op: !LinalgStructuredOpConfig
314+
args:
315+
- !LinalgOperandDefConfig
316+
name: I
317+
usage: InputOperand
318+
type_var: T1
319+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
320+
(s0, s6, s7, s3)>
321+
- !LinalgOperandDefConfig
322+
name: K
323+
usage: InputOperand
324+
type_var: T2
325+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
326+
(s4, s5, s3)>
327+
- !LinalgOperandDefConfig
328+
name: O
329+
usage: OutputOperand
330+
type_var: U
331+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
332+
(s0, s1, s2, s3)>
333+
- !LinalgOperandDefConfig
334+
name: strides
335+
usage: IndexAttribute
336+
type_var: I64
337+
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
338+
-> (s8, s9)>
339+
- !LinalgOperandDefConfig
340+
name: dilations
341+
usage: IndexAttribute
342+
type_var: I64
343+
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
344+
-> (s10, s11)>
345+
indexing_maps: !LinalgIndexingMapsConfig
346+
static_indexing_maps:
347+
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
348+
s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)>
349+
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
350+
s10, s11] -> (d4, d5, d3)>
351+
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
352+
s10, s11] -> (d0, d1, d2, d3)>
353+
iterator_types:
354+
- parallel
355+
- parallel
356+
- parallel
357+
- parallel
358+
- reduction
359+
- reduction
360+
assignments:
361+
- !ScalarAssign
362+
arg: O
363+
value: !ScalarExpression
364+
scalar_apply:
365+
fn_name: add
366+
operands:
367+
- !ScalarExpression
368+
scalar_arg: O
369+
- !ScalarExpression
370+
scalar_apply:
371+
fn_name: mul
372+
operands:
373+
- !ScalarExpression
374+
symbolic_cast:
375+
type_var: U
376+
operands:
377+
- !ScalarExpression
378+
scalar_arg: I
379+
- !ScalarExpression
380+
symbolic_cast:
381+
type_var: U
382+
operands:
383+
- !ScalarExpression
384+
scalar_arg: K
385+
--- !LinalgOpConfig
309386
metadata: !LinalgOpMetadata
310387
name: fill_rng_2d
311388
cpp_class_name: FillRng2DOp
@@ -323,21 +400,21 @@ structured_op: !LinalgStructuredOpConfig
323400
args:
324401
- !LinalgOperandDefConfig
325402
name: min
326-
usage: input
403+
usage: InputOperand
327404
type_var: F64
328405
- !LinalgOperandDefConfig
329406
name: max
330-
usage: input
407+
usage: InputOperand
331408
type_var: F64
332409
- !LinalgOperandDefConfig
333410
name: seed
334-
usage: input
411+
usage: InputOperand
335412
type_var: I32
336413
- !LinalgOperandDefConfig
337414
name: O
338-
usage: output
339-
shape: affine_map<()[s0, s1] -> (s0, s1)>
415+
usage: OutputOperand
340416
type_var: T
417+
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
341418
indexing_maps: !LinalgIndexingMapsConfig
342419
static_indexing_maps:
343420
- affine_map<(d0, d1)[s0, s1] -> ()>

mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,36 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
3030

3131
// -----
3232

33+
func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32(%input : tensor<1x4x16x1xf32>, %filter: tensor<2x2x1xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
34+
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
35+
ins(%input, %filter : tensor<1x4x16x1xf32>, tensor<2x2x1xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
36+
return %0: tensor<1x2x4x1xf32>
37+
}
38+
39+
// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32
40+
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[FILTER_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
41+
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[IN_ARG]], %[[FILTER_ARG]] : f32
42+
// CHECK-NEXT: %[[ADD:.+]] = addf %[[OUT_ARG]], %[[MUL]] : f32
43+
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
44+
// CHECK-NEXT: -> tensor<1x2x4x1xf32>
45+
46+
// -----
47+
48+
func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tensor<1x4x16x1xi32>, %filter: tensor<2x2x1xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
49+
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
50+
ins(%input, %filter : tensor<1x4x16x1xi32>, tensor<2x2x1xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
51+
return %0: tensor<1x2x4x1xi32>
52+
}
53+
54+
// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32
55+
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[FILTER_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
56+
// CHECK-NEXT: %[[MUL:.+]] = muli %[[IN_ARG]], %[[FILTER_ARG]] : i32
57+
// CHECK-NEXT: %[[ADD:.+]] = addi %[[OUT_ARG]], %[[MUL]] : i32
58+
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
59+
// CHECK-NEXT: -> tensor<1x2x4x1xi32>
60+
61+
// -----
62+
3363
func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
3464
%0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
3565
return %0: tensor<16x32xf32>

mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ structured_op: !LinalgStructuredOpConfig
2121
args:
2222
- !LinalgOperandDefConfig
2323
name: O
24-
usage: output
25-
shape: affine_map<()[s0, s1] -> (s0, s1)>
24+
usage: OutputOperand
2625
type_var: T
26+
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
2727
indexing_maps: !LinalgIndexingMapsConfig
2828
static_indexing_maps:
2929
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -86,12 +86,13 @@ structured_op: !LinalgStructuredOpConfig
8686

8787
# @linalg_structured_op
8888
# def test2(I=TensorDef(T, S.M, S.N),
89-
# O=TensorDef(T, S.M, S.N, output=True)):
89+
# O=TensorDef(T, S.M, S.N, output=True),
90+
# strides=AttributeDef(S.SM, S.SN)):
9091
# """Title.
9192

9293
# Detailed description.
9394
# """
94-
# O[D.m, D.n] = I[D.n, D.m]
95+
# O[D.m, D.n] = I[D.n * S.SM, D.m * S.SN]
9596

9697
--- !LinalgOpConfig
9798
metadata: !LinalgOpMetadata
@@ -103,49 +104,67 @@ metadata: !LinalgOpMetadata
103104
Detailed description.
104105
structured_op: !LinalgStructuredOpConfig
105106
args:
106-
- !LinalgOperandDefConfig
107-
name: value
108-
usage: input
109-
type_var: T
110107
- !LinalgOperandDefConfig
111108
name: I
112-
usage: input
113-
shape: affine_map<()[s0, s1] -> (s1, s0)>
109+
usage: InputOperand
114110
type_var: T
111+
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
115112
- !LinalgOperandDefConfig
116113
name: O
117-
usage: output
118-
shape: affine_map<()[s0, s1] -> (s0, s1)>
114+
usage: OutputOperand
119115
type_var: T
116+
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
117+
- !LinalgOperandDefConfig
118+
name: strides
119+
usage: IndexAttribute
120+
type_var: I64
121+
attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
120122
indexing_maps: !LinalgIndexingMapsConfig
121123
static_indexing_maps:
122-
- affine_map<(d0, d1)[s0, s1] -> ()>
123-
- affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
124-
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
124+
- affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>
125+
- affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>
125126
iterator_types:
126127
- parallel
127128
- parallel
128129
assignments:
129130
- !ScalarAssign
130131
arg: O
131132
value: !ScalarExpression
132-
scalar_apply:
133-
fn_name: add
134-
operands:
135-
- !ScalarExpression
136-
scalar_arg: value
137-
- !ScalarExpression
138-
scalar_arg: I
133+
scalar_arg: I
139134

140-
# IMPL-LABEL: Test2Op::iterator_types()
141-
# IMPL-NEXT: { getParallelIteratorTypeName(), getParallelIteratorTypeName() }
135+
# ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2"
136+
137+
# ODS: let arguments =
138+
# ODS-NEXT: Variadic<AnyType>:$inputs,
139+
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
140+
# ODS-NEXT: RankedI64ElementsAttr<[2]>:$strides
141+
142+
# ODS: "Attribute":$strides
143+
# ODS: $_state.addAttribute("strides", strides);
144+
145+
# ODS: bool hasDynamicIndexingMaps();
146+
# ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes();
147+
148+
# IMPL: getSymbolBindings(Test2Op self)
149+
# IMPL: cst2 = self.strides().getValue<int64_t>({ 0 });
150+
# IMPL-NEXT: getAffineConstantExpr(cst2, context)
151+
# IMPL: cst3 = self.strides().getValue<int64_t>({ 1 });
152+
# IMPL-NEXT: getAffineConstantExpr(cst3, context)
142153

143154
# IMPL: Test2Op::indexing_maps()
144-
# IMPL: "affine_map<(d0, d1)[s0, s1] -> ()>"
145-
# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>"
146-
# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>"
155+
# IMPL: = getSymbolBindings(*this);
156+
# IMPL: "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>"
157+
# IMPL: "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>"
158+
159+
# IMPL: Test2Op::getNumRegionArgs() { return 2; }
160+
161+
# IMPL: Test2Op::hasDynamicIndexingMaps() { return true; }
162+
# IMPL: Test2Op::verifyIndexingMapRequiredAttributes()
163+
# IMPL: auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
164+
# IMPL: "missing indexing map required attribute 'strides'"
147165

148166
# IMPL: void Test2Op::regionBuilder(
149-
# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
167+
# IMPL-NEXT: ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
168+
# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
150169

151-
# IMPL: = helper.applyfn__add(block.getArgument(0), block.getArgument(1));
170+
# IMPL: yields.push_back(block.getArgument(0));

0 commit comments

Comments
 (0)