
Broadcasting is broken with LayoutTransform #4508

Closed · pyalex opened this issue Dec 11, 2019 · 4 comments

@pyalex commented Dec 11, 2019

With optimization level >= 3, LayoutTransform fails when converting C -> NCHW8c for an input of shape (1,), which breaks the next operation that is supposed to broadcast this input.

Code to reproduce (it works with opt_level < 3):

from tvm import relay

target = 'llvm'
target_host = 'llvm'

input = relay.var('input', shape=(1, 500, 500, 64), dtype='float32')
kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
bias = relay.var('bias', shape=(64,), dtype='float32')

# (1,)-shaped constant that should broadcast over the conv output
multiplier = relay.const([0.1])

x = relay.nn.conv2d(input, kernel, data_layout='NHWC', kernel_layout="HWIO", kernel_size=(3, 3))
x = relay.add(bias, x)
x = relay.nn.relu(x)

x = relay.multiply(multiplier, x)

fun = relay.Function([input, kernel, bias], x)

# opt_level=3 enables AlterOpLayout, which rewrites the graph to NCHW8c
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(relay.Module({"main": fun}),
                                     target=target,
                                     target_host=target_host)
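
As a stopgap (my own sketch, going only on the note above that opt_level < 3 works), building at opt_level=2 avoids the failure, since AlterOpLayout, the pass that inserts the NCHW8c layout_transform ops, only runs at opt_level 3:

# Workaround sketch: opt_level=2 disables AlterOpLayout, so the graph
# stays in NHWC and the (1,) constant broadcasts normally.
with relay.build_config(opt_level=2):
    graph, lib, params = relay.build(relay.Module({"main": fun}),
                                     target=target,
                                     target_host=target_host)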

Compiled program with the issue:

fn (%input: Tensor[(1, 500, 500, 64), float32], %kernel: Tensor[(3, 3, 64, 64), float32], %bias: Tensor[(64), float32]) -> Tensor[(1, 498, 498, 64), float32] {
  %0 = expand_dims(meta[relay.Constant][0], axis=0, num_newaxis=3);
  %1 = layout_transform(%0, src_layout="NHWC", dst_layout="NCHW8c");
  %2 = expand_dims(%bias, axis=0, num_newaxis=3);
  %3 = layout_transform(%2, src_layout="NHWC", dst_layout="NCHW8c");
  %4 = layout_transform(%input, src_layout="NHWC", dst_layout="NCHW8c");
  %5 = layout_transform(%kernel, src_layout="HWIO", dst_layout="OIHW8i8o");
  %6 = nn.contrib_conv2d_NCHWc(%4, %5, channels=64, kernel_size=[3, 3], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c");
  %7 = add(%3, %6);
  %8 = nn.relu(%7);
  %9 = multiply(%1, %8) Incompatible broadcast type TensorType([1, 0, 1, 1, 8], float32) and TensorType([1, 8, 498, 498, 8], float32); ;
  layout_transform(%9, src_layout="NCHW8c", dst_layout="NHWC") an internal invariant was violated while typechecking your program [18:29:24] /incubator-tvm/src/relay/op/tensor/transform.cc:2382: Check failed: data != nullptr: 
; 
}
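
A note on the reported type (my reading of the error, not verified against LayoutTransformRel): expand_dims turns the (1,) constant into a (1, 1, 1, 1) NHWC tensor, and splitting its channel axis into 8c blocks gives an outer block count of 1 // 8 == 0, which is exactly the invalid TensorType([1, 0, 1, 1, 8], float32) above:

# Illustration only: the NCHW8c split needs C % 8 == 0
C = 1                      # channel dim of the expanded (1,) constant
outer, inner = C // 8, 8   # zero outer blocks of 8 channels each
print((1, outer, 1, 1, inner))  # (1, 0, 1, 1, 8), the invalid type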

Full Stack Trace

an internal invariant was violated while typechecking your program [18:29:24] /incubator-tvm/src/relay/op/tensor/transform.cc:2382: Check failed: data != nullptr: 
Stack trace:
  [bt] (0) 1   libtvm.dylib                        0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  [bt] (1) 2   libtvm.dylib                        0x0000000110f0b198 tvm::relay::LayoutTransformRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) + 280
  [bt] (2) 3   libtvm.dylib                        0x0000000110d6676f void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 95
  [bt] (3) 4   libtvm.dylib                        0x0000000110d666c9 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 137
  [bt] (4) 5   libtvm.dylib                        0x00000001110b55b5 tvm::TypedEnvFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::operator()(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) const + 325
  [bt] (5) 6   libtvm.dylib                        0x00000001110b4f2f tvm::relay::TypeSolver::Solve() + 1071
  [bt] (6) 7   libtvm.dylib                        0x000000011109962c tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 108
  [bt] (7) 8   libtvm.dylib                        0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
  [bt] (8) 9   libtvm.dylib                        0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576

; ' should not has tab or newline.

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000110fdfa06 tvm::relay::transform::PassNode::operator()(tvm::relay::Module const&) const + 54
  [bt] (7) 8   libtvm.dylib                        0x0000000111052f86 tvm::relay::transform::SequentialNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 838
  [bt] (6) 7   libtvm.dylib                        0x000000011105338c tvm::relay::transform::Pass::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 156
  [bt] (5) 6   libtvm.dylib                        0x0000000111051b97 tvm::relay::transform::FunctionPassNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 1607
  [bt] (4) 5   libtvm.dylib                        0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576
  [bt] (3) 4   libtvm.dylib                        0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
  [bt] (2) 3   libtvm.dylib                        0x0000000111099648 tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 136
  [bt] (1) 2   libtvm.dylib                        0x0000000111161769 tvm::relay::ErrorReporter::RenderErrors(tvm::relay::Module const&, bool) + 5433
  [bt] (0) 1   libtvm.dylib                        0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  [bt] (8) 9   libtvm.dylib                        0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576
  [bt] (7) 8   libtvm.dylib                        0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
  [bt] (6) 7   libtvm.dylib                        0x000000011109962c tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 108
  [bt] (5) 6   libtvm.dylib                        0x00000001110b4f2f tvm::relay::TypeSolver::Solve() + 1071
  [bt] (4) 5   libtvm.dylib                        0x00000001110b55b5 tvm::TypedEnvFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::operator()(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) const + 325
  [bt] (3) 4   libtvm.dylib                        0x0000000110d666c9 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 137
  [bt] (2) 3   libtvm.dylib                        0x0000000110d6676f void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 95
  [bt] (1) 2   libtvm.dylib                        0x0000000110f0b198 tvm::relay::LayoutTransformRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) + 280
  [bt] (0) 1   libtvm.dylib                        0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "incubator-tvm/src/relay/ir/error.cc", line 132
@yzhliu (Member) commented Dec 11, 2019

The LayoutTransform operator is not intended to handle such a transform. I think we need to enhance FCorrectLayout of the broadcast ops to insert an extra expand_dims. @anijain2305 any thoughts?

@anijain2305 (Contributor) commented

#4577

This is not just an expand_dims problem. The input shape is (1,); to take it from C to NCHW8c, we would first need a repeat to make the input shape at least (8,).

Instead, a simpler way is to prevent the insertion of expand_dims and layout_transform altogether for scalars. Scalars can be broadcast easily even when the second tensor has been layout-transformed, as the sketch below illustrates. The above PR does just that.
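
A minimal numpy illustration of that point (my own sketch, not the PR's code): a true scalar broadcasts against the NCHW8c tensor with no layout handling at all, while per-channel data must itself be repacked into 8c blocks to match:

import numpy as np

x_nchw8c = np.zeros((1, 8, 498, 498, 8), dtype="float32")  # C=64 packed as 8 x 8c
scalar = np.float32(0.1)
y = scalar * x_nchw8c                   # fine: a scalar needs no layout transform
bias = np.zeros((64,), dtype="float32")
bias_8c = bias.reshape(1, 8, 1, 1, 8)   # a (64,) vector must be repacked to match
z = bias_8c + x_nchw8c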

@anijain2305 (Contributor) commented

@pyalex This is fixed. Can you please check and close the issue? Thanks!

@pyalex (Author) commented Jan 20, 2020

Problem solved! Thanks

@pyalex closed this as completed Jan 20, 2020