Skip to content

Commit bc07634

Browse files
gpetters94Mahesh Ravishankar
authored andcommitted
Adding a named op for grouped convolutions
1 parent f9710d1 commit bc07634

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,105 @@ structured_op: !LinalgStructuredOpConfig
16441644
- !ScalarExpression
16451645
scalar_arg: K
16461646
--- !LinalgOpConfig
1647+
metadata: !LinalgOpMetadata
1648+
name: conv_2d_ngchw_fgchw
1649+
cpp_class_name: Conv2DNgchwFgchwOp
1650+
doc: |-
1651+
Performs 2-D convolution.
1652+
1653+
Layout:
1654+
* Input: NGCHW.
1655+
* Kernel: FGCHW.
1656+
1657+
Numeric casting is performed on the operands to the inner multiply, promoting
1658+
them to the same data type as the accumulator/output.
1659+
implements:
1660+
- LinalgConvolutionOpInterface
1661+
structured_op: !LinalgStructuredOpConfig
1662+
args:
1663+
- !LinalgOperandDefConfig
1664+
name: I
1665+
kind: input_tensor
1666+
type_var: T1
1667+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> (s0,
1668+
s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)>
1669+
- !LinalgOperandDefConfig
1670+
name: K
1671+
kind: input_tensor
1672+
type_var: T2
1673+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> (s10,
1674+
s1, s11, s4, s8)>
1675+
- !LinalgOperandDefConfig
1676+
name: O
1677+
kind: output_tensor
1678+
type_var: U
1679+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] -> (s0,
1680+
s1, s10, s2, s6)>
1681+
- !LinalgOperandDefConfig
1682+
name: strides
1683+
kind: index_attr
1684+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
1685+
(s3, s7)>
1686+
default_indices:
1687+
- 1
1688+
- 1
1689+
- !LinalgOperandDefConfig
1690+
name: dilations
1691+
kind: index_attr
1692+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
1693+
(s5, s9)>
1694+
default_indices:
1695+
- 1
1696+
- 1
1697+
indexing_maps: !LinalgIndexingMapsConfig
1698+
static_indexing_maps:
1699+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
1700+
s9, s10, s11] -> (d0, d1, d5, d3 * s3 + d6 * s5, d4 * s7 + d7 * s9)>
1701+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
1702+
s9, s10, s11] -> (d2, d1, d5, d6, d7)>
1703+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
1704+
s9, s10, s11] -> (d0, d1, d2, d3, d4)>
1705+
iterator_types:
1706+
- parallel
1707+
- parallel
1708+
- parallel
1709+
- parallel
1710+
- parallel
1711+
- reduction
1712+
- reduction
1713+
- reduction
1714+
assignments:
1715+
- !ScalarAssign
1716+
arg: O
1717+
value: !ScalarExpression
1718+
scalar_fn:
1719+
kind: binary
1720+
fn_name: add
1721+
operands:
1722+
- !ScalarExpression
1723+
scalar_arg: O
1724+
- !ScalarExpression
1725+
scalar_fn:
1726+
kind: binary
1727+
fn_name: mul
1728+
operands:
1729+
- !ScalarExpression
1730+
scalar_fn:
1731+
kind: type
1732+
fn_name: cast_signed
1733+
type_var: U
1734+
operands:
1735+
- !ScalarExpression
1736+
scalar_arg: I
1737+
- !ScalarExpression
1738+
scalar_fn:
1739+
kind: type
1740+
fn_name: cast_signed
1741+
type_var: U
1742+
operands:
1743+
- !ScalarExpression
1744+
scalar_arg: K
1745+
--- !LinalgOpConfig
16471746
metadata: !LinalgOpMetadata
16481747
name: conv_3d_ndhwc_dhwcf
16491748
cpp_class_name: Conv3DNdhwcDhwcfOp

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,27 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
366366
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
367367
D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
368368

369+
@linalg_structured_op
370+
def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH,
371+
S.OW * S.SW + S.KW * S.DW),
372+
K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
373+
O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
374+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
375+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
376+
"""Performs 2-D grouped convolution.
377+
378+
Layout:
379+
* Input: NGCHW.
380+
* Kernel: FGCHW.
381+
382+
Numeric casting is performed on the operands to the inner multiply, promoting
383+
them to the same data type as the accumulator/output.
384+
"""
385+
implements(ConvolutionOpInterface)
386+
domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
387+
O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
388+
U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
389+
D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])
369390

370391
@linalg_structured_op
371392
def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,

0 commit comments

Comments
 (0)