[Relay, TOPI] Deformable conv2d #2908

Merged · 6 commits · Mar 29, 2019

Changes from all commits
61 changes: 61 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -456,6 +456,67 @@ struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
  }
};


/*! \brief Attributes for DeformableConv2D operator */
struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
  Array<IndexExpr> strides;
  Array<IndexExpr> padding;
  Array<IndexExpr> dilation;
  int deformable_groups;
  int groups;
  IndexExpr channels;
  Array<IndexExpr> kernel_size;
  std::string data_layout;
  std::string kernel_layout;
  std::string out_layout;
  DataType out_dtype;

  TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") {
    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
        .describe("Specifies the strides of the convolution.");
    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
        .describe("If padding is non-zero, then the input is implicitly zero-padded "
                  "on both sides for padding number of points.");
    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
        .describe("Specifies the dilation rate to use for dilated convolution.");
    TVM_ATTR_FIELD(deformable_groups).set_default(1)
        .describe("Controls the connections between inputs and offsets. "
                  "Input channels are partitioned into multiple deformable groups. Offsets "
                  "are shared across input channels in the same deformable group.");
    TVM_ATTR_FIELD(groups).set_default(1)
        .describe("Controls the connections between inputs and outputs. "
                  "At groups=1, all inputs are convolved to all outputs. "
                  "At groups=2, the operation becomes equivalent to having two convolution "
                  "layers side by side, each seeing half the input channels and producing "
                  "half the output channels, with the two results subsequently concatenated.");
    TVM_ATTR_FIELD(channels)
        .describe("The number of output channels in the convolution."
                  " If it is not set, it is inferred from the shape of the weight.")
        .set_default(NullValue<IndexExpr>());
    TVM_ATTR_FIELD(kernel_size)
        .describe("Specifies the dimensions of the convolution window.")
        .set_default(NullValue<Array<IndexExpr> >());
    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
                  "'N', 'C', 'H', 'W' stand for batch, channel, height, and width "
                  "dimensions respectively. Convolution is applied on the 'H' and "
                  "'W' dimensions.");
    TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
        .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc. "
                  "'O', 'I', 'H', 'W' stand for num_filter, input_channel, height, and width "
                  "dimensions respectively.");
    TVM_ATTR_FIELD(out_layout).set_default("")
        .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc. "
                  "'N', 'C', 'H', 'W' stand for batch, channel, height, and width "
                  "dimensions respectively. Defaults to the same layout as the input.");

    // use 0 bits to indicate none.
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type; set to an explicit type under a mixed-precision setting.");
  }
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
2 changes: 2 additions & 0 deletions python/tvm/autotvm/task/relay_integration.py
@@ -53,6 +53,7 @@ def extract_from_program(func, params, ops, target, target_host=None):
                                 topi.nn.group_conv2d_nchw],
        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        tvm.relay.op.nn.dense: [topi.nn.dense],
        tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
    }

    topi_funcs = []
@@ -126,6 +127,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
                                 topi.nn.group_conv2d_nchw],
        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        tvm.relay.op.nn.dense: [topi.nn.dense],
        tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
    }

    topi_funcs = []
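Example (not part of the diff): a minimal sketch of how the mapping above is picked up by AutoTVM task extraction. It assumes a CUDA-enabled TVM build with this PR applied; the shapes and the empty params dict are illustrative.

from tvm import autotvm, relay

# Build a tiny Relay function containing a deformable conv2d.
data = relay.var("data", shape=(1, 4, 16, 16))
offset = relay.var("offset", shape=(1, 18, 14, 14))   # 2 * kh * kw offset channels per deformable group
weight = relay.var("weight", shape=(8, 4, 3, 3))
out = relay.nn.deformable_conv2d(data, offset, weight, kernel_size=(3, 3), channels=8)
func = relay.Function([data, offset, weight], out)

# The op-to-TOPI mapping added above lets extraction return a tuning task for this workload.
tasks = autotvm.task.extract_from_program(
    func, params={}, ops=(relay.op.nn.deformable_conv2d,), target="cuda")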
11 changes: 11 additions & 0 deletions python/tvm/autotvm/task/topi_integration.py
@@ -68,6 +68,7 @@ def __init__(self):
            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
            topi.nn.dense: "topi_nn_dense",
            topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
        }

        self.topi_to_schedule = {
@@ -78,6 +79,7 @@ def __init__(self):
            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
            topi.nn.dense: [topi.generic.schedule_dense],
            topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
        }

        self._register_tracing()
@@ -172,6 +174,15 @@ def _topi_nn_dense(*args, **kwargs):
                return s, [data, weight, bias, C]
            return s, [data, weight, C]

        @register("topi_nn_deformable_conv2d_nchw")
        def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, Offset, W = args[:3]
            C = topi.nn.deformable_conv2d_nchw(*args, **kwargs)
            s = topi.generic.schedule_deformable_conv2d_nchw([C])
            return s, [A, Offset, W, C]

    def reset(self, wanted_topi_funcs):
        """Reset task collections

20 changes: 20 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -603,6 +603,25 @@ def _mx_smooth_l1(inputs, attrs):
_op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq))


def _mx_deformable_convolution(inputs, attrs):
    new_attrs = {}
    new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
    new_attrs["strides"] = attrs.get_int_tuple("stride")
    new_attrs["padding"] = attrs.get_int_tuple("pad")
    new_attrs["dilation"] = attrs.get_int_tuple("dilate")
    new_attrs["channels"] = attrs.get_int("num_filter")
    new_attrs["deformable_groups"] = attrs.get_int("num_deformable_group", 1)
    new_attrs["groups"] = attrs.get_int("num_group", 1)
    assert attrs.get_str("layout", "NCHW") == "NCHW", "Deformable conv2d only supports NCHW layout"
    use_bias = not attrs.get_bool("no_bias", False)
    res = _op.nn.deformable_conv2d(inputs[0], inputs[1], inputs[2], **new_attrs)
    if use_bias:
        assert len(inputs) == 4
        res = _op.nn.bias_add(res, inputs[3])
    return res


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
@@ -748,6 +767,7 @@ def _mx_smooth_l1(inputs, attrs):
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
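Example (not part of the diff): a sketch of how the converter above is reached when importing an MXNet graph that uses the contrib operator. Shapes and names are illustrative, and the return value of from_mxnet differs across TVM versions.

import mxnet as mx
from tvm import relay

data = mx.sym.var("data")
offset = mx.sym.var("offset")
sym = mx.sym.contrib.DeformableConvolution(
    data=data, offset=offset, kernel=(3, 3), num_filter=8, no_bias=True, name="defconv")

shape_dict = {"data": (1, 4, 16, 16),
              "offset": (1, 18, 14, 14),
              "defconv_weight": (8, 4, 3, 3)}
# Newer TVM versions return an IRModule here instead of a function.
func, params = relay.frontend.from_mxnet(sym, shape_dict)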
23 changes: 23 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -426,3 +426,26 @@ def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):

reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.deformable_conv2d")
def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
"""Compute definition of deformable_conv2d"""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
deformable_groups = attrs.deformable_groups
groups = attrs.groups
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
with target:
out = topi.nn.deformable_conv2d_nchw(inputs[0], inputs[1], inputs[2], strides, padding,
dilation, deformable_groups, groups, out_dtype)
return [out]

@reg.register_schedule("nn.deformable_conv2d")
def schedule_deformable_conv2d(attrs, outs, target):
"""Schedule definition of deformable_conv2d"""
with target:
return topi.generic.schedule_deformable_conv2d_nchw(outs)

reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
73 changes: 73 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1105,3 +1105,76 @@ def contrib_conv2d_winograd_nnpack_weight_transform(weight,
"""
return _make.contrib_conv2d_winograd_nnpack_weight_transform(
weight, convolution_algorithm, out_dtype)


def deformable_conv2d(data,
                      offset,
                      weight,
                      strides=(1, 1),
                      padding=(0, 0),
                      dilation=(1, 1),
                      deformable_groups=1,
                      groups=1,
                      channels=None,
                      kernel_size=None,
                      data_layout='NCHW',
                      kernel_layout='OIHW',
                      out_layout='',
                      out_dtype=''):
    r"""Deformable 2d convolution.

    The deformable convolution operation is described in https://arxiv.org/abs/1703.06211

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    offset : tvm.relay.Expr
        The offset expressions.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : tuple of int, optional
        The strides of the convolution.

    padding : tuple of int, optional
        The padding applied to both sides of the input before convolution.

    dilation : tuple of int, optional
        Specifies the dilation rate to be used for dilated convolution.

    deformable_groups : int, optional
        Number of deformable groups.

    groups : int, optional
        Number of groups for grouped convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : str, optional
        Layout of the output; by default, out_layout is the same as data_layout.

    out_dtype : str, optional
        Specifies the output data type for mixed-precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.

    """
    return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation,
                                   deformable_groups, groups, channels, kernel_size, data_layout,
                                   kernel_layout, out_layout, out_dtype)
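Example (not part of the diff): minimal usage of the new API. The offset input is expected to carry deformable_groups * 2 * kh * kw channels, i.e. one (dy, dx) pair per kernel tap; the shapes below are illustrative.

from tvm import relay

data = relay.var("data", shape=(1, 4, 16, 16))
offset = relay.var("offset", shape=(1, 2 * 3 * 3, 14, 14))
weight = relay.var("weight", shape=(8, 4, 3, 3))
out = relay.nn.deformable_conv2d(data, offset, weight, strides=(1, 1), padding=(0, 0),
                                 kernel_size=(3, 3), channels=8)
func = relay.Function([data, offset, weight], out)

# Type inference reports the output as Tensor[(1, 8, 14, 14), float32].
# (relay.ir_pass.infer_type was current at the time of this PR; newer TVM uses relay.transform.InferType.)
print(relay.ir_pass.infer_type(func))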