Commit

Address comment
kevinthesun committed Oct 25, 2019
1 parent 61f3fef commit f775105
Showing 3 changed files with 5 additions and 13 deletions.
python/tvm/autotvm/util.py (2 changes: 1 addition & 1 deletion)
@@ -180,7 +180,7 @@ def get_const_tuple(in_tuple):
         if isinstance(elem, expr.Var):
             ret.append(elem)
         elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
-            elem = tvm.ir_pass.Simplify(elem)
+            elem = ir_pass.Simplify(elem)
             if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
                 ret.append(elem)
             else:
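For context, a minimal, hypothetical usage sketch of get_const_tuple after this change; it is not part of the commit and assumes the TVM Python API of this era (tvm.var / tvm.placeholder). Symbolic dimensions come back as Var objects while constant dimensions are folded to Python ints.

# Hypothetical sketch, not from this commit; assumes the pre-0.6 TVM API.
import tvm
from tvm.autotvm.util import get_const_tuple

n = tvm.var("n")                                    # symbolic batch dimension
x = tvm.placeholder((n, 3, 224, 224), "float32", name="x")

# Constant dims are simplified to Python ints; the symbolic dim is kept as a Var.
shape = get_const_tuple(x.shape)
print(shape)   # expected: (n, 3, 224, 224)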
python/tvm/relay/op/_tensor.py (10 changes: 1 addition & 9 deletions)
@@ -154,19 +154,11 @@ def broadcast_shape_func(attrs, inputs, out_ndims):
     """
     return [_broadcast_shape_func(*inputs, out_ndims[0])]
 
-@script
-def _elemwise_shape_func(data_shape):
-    out = output_tensor((data_shape.shape[0],), "int64")
-    for i in const_range(data_shape.shape[0]):
-        out[i] = data_shape[i]
-
-    return out
-
 def elemwise_shape_func(attrs, inputs, _):
     """
     Shape function for elemwise op.
     """
-    return [_elemwise_shape_func(inputs[0])]
+    return [topi.math.identity(inputs[0])]
 
 register_shape_func("cast", False, cast_shape_func)
 
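A hedged illustration of what the new one-liner relies on (not code from this commit): topi.math.identity already performs the element-for-element copy that the removed hybrid-script helper implemented, so the elementwise shape function can return it directly. The op name "exp" and the _reg alias below are only for illustration; the real registrations live elsewhere in _tensor.py.

# Illustrative sketch only; "exp" is an arbitrary example op name.
import topi
from tvm.relay.op import op as _reg

def elemwise_shape_func(attrs, inputs, _):
    # An elementwise op's output shape equals its input shape, so the shape
    # function can simply return a copy of the input shape tensor.
    return [topi.math.identity(inputs[0])]

# data_dependant=False: the shape function only needs input shapes, not values.
_reg.register_shape_func("exp", False, elemwise_shape_func)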
src/lang/data_layout.cc (6 changes: 3 additions & 3 deletions)
@@ -288,12 +288,12 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
   // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
   // e.g., (C * 16 + c) / 32
   std::unordered_map<const Variable*, Expr> bind_map;
-  std::unordered_set<std::string> symbolic_var_set;
+  std::unordered_set<size_t> symbolic_var_set;
   for (size_t i = 0; i < src_shape.size(); ++i) {
     Expr orig_shape = src_shape[i];
     IterVar orig_axis = src_axis[i];
     if (orig_shape.as<ir::Any>()) {
-      symbolic_var_set.insert(orig_axis->var->name_hint);
+      symbolic_var_set.insert(i);
     }
     if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
       if (orig_shape.defined()) {
@@ -321,7 +321,7 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
     if (!LayoutAxis::Get(axis).IsPrimal()) {
       result.push_back(axis->dom->extent);
     } else {
-      if (symbolic_var_set.count(axis->var->name_hint)) {
+      if (symbolic_var_set.count(i)) {
         result.push_back(ir::Any::make());
       } else {
         result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
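At the Relay level, the behavior this data_layout.cc change is meant to preserve can be sketched roughly as follows (a hypothetical example, not from this commit, assuming relay.Any() and layout_transform from this TVM version): a symbolic source dimension stays Any in the transformed shape, and the fix records which source dimension was symbolic by its index i rather than by the axis variable's name_hint.

# Hypothetical sketch of the intended behavior; assumes the TVM API of this era.
import tvm
from tvm import relay

# NCHW tensor whose batch dimension is symbolic (Any).
x = relay.var("x", shape=(relay.Any(), 32, 28, 28), dtype="float32")
y = relay.layout_transform(x, "NCHW", "NCHW16c")

mod = relay.Module.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
# The inferred result type should be Tensor[(?, 2, 28, 28, 16), float32]:
# the symbolic dim is kept as Any and the channel dim 32 splits into 2 * 16.
print(mod)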
