
[Relay][AlterLayout] Broadcast with scalar shape
anijain2305 committed Dec 23, 2019
1 parent f9bc748 commit 09702a0
Showing 3 changed files with 71 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/relay/pass/alter_op_layout.cc
@@ -59,6 +59,11 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
  Expr input_expr = raw;
  Layout new_src_layout = src_layout;
  if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
    // If the input is a scalar, no layout transformation is needed: a scalar
    // broadcasts against the other operand even when that operand's layout
    // has been transformed.
    if (IsScalar(input_expr)) {
      return raw;
    }
    int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
    new_src_layout = src_layout.ExpandPrimal(dst_layout);
    input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
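The effect of this early return can be illustrated with plain NumPy broadcasting: a (1,)-shaped operand lines up with any layout, so no layout_transform needs to be inserted before the broadcast op. A minimal sketch, assuming hypothetical shapes (NumPy is used here only for illustration; it is not part of the change):

import numpy as np

# Data already transformed to NCHW16c: (N, C//16, H, W, 16).
data_nchw16c = np.random.rand(1, 4, 56, 56, 16).astype("float32")
# A "scalar" operand with shape (1,), as matched by IsScalar below.
bias = np.array([3.0], dtype="float32")

# The (1,) operand broadcasts against the transformed layout directly;
# every element receives the same bias, so no layout transform is required.
out = data_nchw16c + bias
assert out.shape == data_nchw16c.shape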
19 changes: 19 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -199,6 +199,25 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
  return *channels;
}

/*!
 * \brief Check whether the expression is a single-value tensor (scalar),
 *        i.e. a rank-1 tensor whose only dimension is a static 1.
 * \param expr The expression.
 * \return True if the expression is a single-value tensor.
 */
inline bool IsScalar(const Expr& expr) {
  if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
    if (tensor_type->shape.size() == 1) {
      auto size = tensor_type->shape[0];
      if (auto index_expr = size.as<IntImm>()) {
        if (index_expr->value == 1) {
          return true;
        }
      }
    }
  }
  return false;
}

/*!
* \brief Create a Constant with a scalar
*
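For orientation, a short Python-side sketch of what the new IsScalar helper does and does not match (the variable names are illustrative only):

from tvm import relay

# Matched by IsScalar: a rank-1 tensor with a single static element, shape (1,).
b = relay.var("b", shape=(1,))
# Not matched: a true rank-0 scalar (shape ()) fails the shape.size() == 1 check.
s = relay.var("s", shape=())
# Not matched: a rank-1 tensor with more than one element.
v = relay.var("v", shape=(4,))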
47 changes: 47 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -318,6 +318,52 @@ def expected():

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_broadcast_scalar_op():
    """Test altering the layout of a conv2d.
    The scalar operand of the broadcast add should be left untouched;
    no layout transform is inserted for it.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        b = relay.var("b", shape=(1,))
        weight = relay.var("weight")
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.add(y, b)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        b = relay.var("b", shape=(1,))
        w = relay.var("weight")

        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        y = relay.nn.conv2d(y, w,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        # The scalar operand b keeps its (1,) shape: no layout_transform
        # or expand_dims is expected around it.
        y = relay.add(y, b)

        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        a = before()
        a = run_opt_pass(a, [transform.CanonicalizeOps(),
                             transform.AlterOpLayout()])
        b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_scalar():
    """Test altering the layout of a conv2d.
    The layout of broadcast operators and the weight should be changed accordingly.
@@ -980,6 +1026,7 @@ def expected():
    test_alter_layout_dual_path()
    test_alter_layout_resnet()
    test_alter_layout_broadcast_op()
    test_alter_layout_broadcast_scalar_op()
    test_alter_layout_scalar()
    test_alter_layout_concatenate()
    test_alter_layout_nchw_upsamping_op()
