Skip to content

Commit

Permalink
[Relay][AlterLayout] Broadcast with scalar shape
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Dec 26, 2019
1 parent 73dda6b commit 7509f85
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,28 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
return *channels;
}

/*!
* \brief Is single value tensor (scalar).
* \param expr The expr.
* \return True if single value tensor.
*/
inline bool IsScalar(const Expr& expr) {
if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
for (auto dim_index_expr : tensor_type->shape) {
if (auto dim_index = dim_index_expr.as<IntImm>()) {
if (dim_index->value != 1) {
return false;
}
} else {
return false;
}
}
} else {
return false;
}
return true;
}

/*!
* \brief Create a Constant with a scalar
*
Expand Down
5 changes: 5 additions & 0 deletions src/relay/pass/transform_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ class TransformMemorizer : public NodeRef {
Expr input_expr = raw;
Layout new_src_layout = src_layout;
if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
// If scalar, then no need of layout transformation as scalar can be broadcasted easily even if
// the other operand has a transformed layout.
if (IsScalar(input_expr)) {
return raw;
}
int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
new_src_layout = src_layout.ExpandPrimal(dst_layout);
input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
Expand Down
65 changes: 65 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,70 @@ def expected():

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_broadcast_scalar_op():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
"""
def before():
x = relay.var("x", shape=(1, 500, 500, 64))
kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
bias = relay.var("bias", shape=(64,))
multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')

y = relay.nn.conv2d(x, kernel,
data_layout='NHWC',
kernel_layout="HWIO",
kernel_size=(3, 3))
y = relay.add(bias, y)
y = relay.nn.relu(y)

y = relay.multiply(multiplier1, y)
y = relay.multiply(y, multiplier2)
y = relay.Function(analysis.free_vars(y), y)
return y

def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)

def expected():
x = relay.var("x", shape=(1, 500, 500, 64))
kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
bias = relay.var("bias", shape=(64,))
multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')

b = relay.expand_dims(bias, axis=0, num_newaxis=3)
b = relay.layout_transform(b, "NHWC", "NCHW16c")

y = relay.layout_transform(x, "NHWC", "NCHW16c")
y = relay.nn.conv2d(y, kernel,
data_layout='NCHW16c',
kernel_layout="HWIO",
kernel_size=(3, 3))

y = relay.add(b, y)
y = relay.nn.relu(y)

y = relay.multiply(multiplier1, y)
y = relay.multiply(y, multiplier2)
y = relay.layout_transform(y, "NCHW16c", "NHWC")
y = relay.Function(analysis.free_vars(y), y)
return y

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_scalar():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
Expand Down Expand Up @@ -980,6 +1044,7 @@ def expected():
test_alter_layout_dual_path()
test_alter_layout_resnet()
test_alter_layout_broadcast_op()
test_alter_layout_broadcast_scalar_op()
test_alter_layout_scalar()
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
Expand Down

0 comments on commit 7509f85

Please sign in to comment.