[relay] PrimitiveInliner doesn't correctly handle match expressions that have more than one use #4758

Closed
abergeron opened this issue Jan 21, 2020 · 2 comments

@abergeron
Contributor

I have this sample program:

import tvm
from tvm import relay
from tvm.relay import adt

ctx = tvm.ndarray.context('cpu', 0)

mod = relay.Module({})

union_type = relay.GlobalTypeVar("u")
c0_type = relay.ty.TupleType([relay.ty.scalar_type('int32'), union_type()])
c0 = adt.Constructor("c0", [c0_type], union_type)
c1 = adt.Constructor("c1", [relay.ty.TupleType([])], union_type)

mod[union_type] = adt.TypeData(union_type, [], [c0, c1])

gv = relay.GlobalVar("fn")
p = relay.var('p', union_type())
v = relay.Var('v')
cond = adt.Match(p, [adt.Clause(adt.PatternConstructor(c0, [adt.PatternWildcard()]), relay.const(True)),
                     adt.Clause(adt.PatternWildcard(), relay.const(False))])

mm = adt.Match(p, [adt.Clause(adt.PatternConstructor(c0, [adt.PatternVar(v)]), v)], complete=False)

fn = relay.Function(
    [p],
    relay.If(
        cond,
        relay.const(0),
        relay.add(relay.TupleGetItem(mm, 0),
                  relay.Call(gv, [relay.TupleGetItem(mm, 1)]))
    ),
    ret_type=relay.ty.scalar_type('int32')
)

mod[gv] = fn

q = relay.var("q", union_type())
mod["main"] = relay.Function([q], relay.Call(gv, [q]))

print(str(mod))

vm = relay.create_executor(mod=mod, ctx=ctx, target='llvm', kind="vm")

The module looks like this before compiling:

v0.0.4
type u {
  c0((int32, u[])),
  c1(()),
}

def @main(%q: u[]) -> int32 {
  @fn(%q) /* ty=int32 */
}

def @fn(%p: u[]) -> int32 {
  %0 = match (%p) {
    c0(_) => True /* ty=bool */,
    _ => False /* ty=bool */,
  };
  if (%0) {
    0 /* ty=int32 */
  } else {
    %1 = match? (%p) {
      c0(%v: (int32, u[])) => %v,
    };
    %2 = %1.0;
    %3 = %1.1;
    %4 = @fn(%3) /* ty=int32 */;
    add(%2, %4) /* ty=int32 */
  }
}

I get this error when compiling with kind="vm", but the program works correctly with kind="debug":

Traceback (most recent call last):

  File "tst.py", line 42, in <module>
    vm = relay.create_executor(mod=mod, ctx=ctx, target='llvm', kind="vm")

  File "/home/anakha/ext/tvm/python/tvm/relay/build_module.py", line 411, in create_executor
    return VMExecutor(mod, ctx, target)

  File "/home/anakha/ext/tvm/python/tvm/relay/backend/vm.py", line 540, in __init__
    self.executable = compile(mod, target)

  File "/home/anakha/ext/tvm/python/tvm/relay/backend/vm.py", line 399, in compile
    compiler.lower(mod, target, target_host)

  File "/home/anakha/ext/tvm/python/tvm/relay/backend/vm.py", line 455, in lower
    self._lower(mod, target, target_host)

  File "tvm/_ffi/_cython/./function.pxi", line 304, in tvm._ffi._cy3.core.FunctionBase.__call__

  File "tvm/_ffi/_cython/./function.pxi", line 239, in tvm._ffi._cy3.core.FuncCall

  File "tvm/_ffi/_cython/./function.pxi", line 228, in tvm._ffi._cy3.core.FuncCall3

  File "tvm/_ffi/_cython/./base.pxi", line 157, in tvm._ffi._cy3.core.CALL

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/anakha/ext/tvm/build/libtvm.so(+0x2097b7d) [0x7f0bebd40b7d]
  [bt] (7) /home/anakha/ext/tvm/build/libtvm.so(+0x2097d9c) [0x7f0bebd40d9c]
  [bt] (6) /home/anakha/ext/tvm/build/libtvm.so(+0x2097f08) [0x7f0bebd40f08]
  [bt] (5) /home/anakha/ext/tvm/build/libtvm.so(+0x2097579) [0x7f0bebd40579]
  [bt] (4) /home/anakha/ext/tvm/build/libtvm.so(tvm::relay::vm::PrimitiveInliner::Inline()+0x276) [0x7f0bebd41a1c]
  [bt] (3) /home/anakha/ext/tvm/build/libtvm.so(tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool)+0xcb) [0x7f0beb3daca3]
  [bt] (2) /home/anakha/ext/tvm/build/libtvm.so(tvm::RunTypeCheck(tvm::IRModule const&, tvm::GlobalVar const&, tvm::relay::Function)+0x5b) [0x7f0beb3da662]
  [bt] (1) /home/anakha/ext/tvm/build/libtvm.so(tvm::relay::DeDup(tvm::RelayExpr const&)+0x106) [0x7f0bebb460cd]
  [bt] (0) /home/anakha/ext/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4a) [0x7f0beb36f884]
  File "/home/anakha/ext/tvm/src/relay/pass/de_duplicate.cc", line 109
TVMError: Check failed: WellFormed(e): v0.0.4
fn (%p: u[]) -> int32 {
  %0 = match (%p) {
    c0(_) => True /* ty=bool */,
    _ => False /* ty=bool */,
  };
  if (%0) {
    0 /* ty=int32 */
  } else {
    %1 = match? (%p) {
      c0(%v: (int32, u[])) => %v,
    };
    %2 = %1.0;
    %3 = match? (%p) {
      c0(%v: (int32, u[])) => %v,
    };
    %4 = %3.1;
    %5 = @fn(%4);
    add(%2, %5)
  }
}

The only thing I notice is that something duplicated the match in the else branch. %1 and %3 are now two separate copies of the same match, and both of them bind %v, which probably makes the code ill-formed. I traced the WellFormed error and it comes from %v, though I may be wrong about this.
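
To make the diagnosis concrete, here is a toy model in plain Python. It is not TVM's API: `Var`, `Match`, `TupleGetItem`, `Add` and `rebound_vars` are hypothetical names. As far as I understand, Relay's well-formedness check rejects a program in which the same Var is bound more than once. Sharing a single Match node between two uses does not count as a second binding. The sketch applies that rule to a simplified else branch, before and after the pass. The recursive `@fn` call is dropped, so the "after" version is just `add(%1.0, %3.1)`.

```python
from typing import List, Tuple


class Var:
    """Toy variable: like a Relay Var, identified by object identity, not by name."""

    def __init__(self, name: str):
        self.name = name


class Match:
    """Toy `match (data) { clauses }`; each clause is (pattern_vars, body)."""

    def __init__(self, data, clauses: List[Tuple[List[Var], object]]):
        self.data = data
        self.clauses = clauses


class TupleGetItem:
    def __init__(self, tup, index: int):
        self.tup = tup
        self.index = index


class Add:
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs


def rebound_vars(expr) -> List[str]:
    """Names of Vars bound by more than one pattern.

    Each node is visited once (by identity), so one Match node that is
    used twice is a single binding site."""
    visited, bound, rebound = set(), set(), []

    def visit(e):
        if id(e) in visited:  # shared node: already checked
            return
        visited.add(id(e))
        if isinstance(e, Match):
            visit(e.data)
            for pattern_vars, body in e.clauses:
                for var in pattern_vars:
                    if id(var) in bound:
                        rebound.append(var.name)
                    bound.add(id(var))
                visit(body)
        elif isinstance(e, TupleGetItem):
            visit(e.tup)
        elif isinstance(e, Add):
            visit(e.lhs)
            visit(e.rhs)

    visit(expr)
    return rebound


p, v = Var("p"), Var("v")

# Before the pass: one match? node (%1) used by both %1.0 and %1.1.
m = Match(p, [([v], v)])
before = Add(TupleGetItem(m, 0), TupleGetItem(m, 1))

# After the pass: two distinct copies (%1 and %3) that both bind %v.
m1 = Match(p, [([v], v)])
m3 = Match(p, [([v], v)])
after = Add(TupleGetItem(m1, 0), TupleGetItem(m3, 1))

print(rebound_vars(before))  # []
print(rebound_vars(after))   # ['v']
```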

@abergeron
Contributor Author

After some experimentation, I found that compilation works if I comment out the InlinePrimitives call at line 921 of compiler.cc (https://github.com/apache/incubator-tvm/blob/master/src/relay/backend/vm/compiler.cc#L921). That call seems to be what introduces the duplicated match.
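
Below is one plausible mechanism, which I have not confirmed against the TVM source. Suppose a rewrite pass rebuilds a node once per use instead of memoizing on node identity. Then a match that is used twice (here by `%1.0` and `%1.1`) comes out as two distinct copies, which matches the `%1`/`%3` pair in the dump. The toy below is plain Python with hypothetical names (`Node`, `rebuild_naive`, `rebuild_memo`, `count_match_nodes`). It only illustrates the sharing problem and is not how `PrimitiveInliner` is actually written.

```python
class Node:
    """Toy IR node; identity matters, so two rebuilt copies are distinct nodes."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args


def rebuild_naive(e):
    """Rebuild without a memo table: a shared child is rebuilt once per use."""
    if not isinstance(e, Node):
        return e
    return Node(e.op, *(rebuild_naive(a) for a in e.args))


def rebuild_memo(e, memo=None):
    """Rebuild with a memo keyed on node identity: each distinct node is rebuilt once."""
    if memo is None:
        memo = {}
    if not isinstance(e, Node):
        return e
    if id(e) not in memo:
        memo[id(e)] = Node(e.op, *(rebuild_memo(a, memo) for a in e.args))
    return memo[id(e)]


def count_match_nodes(e, seen=None):
    """Count distinct 'match' nodes reachable from e."""
    if seen is None:
        seen = set()
    if not isinstance(e, Node) or id(e) in seen:
        return 0
    seen.add(id(e))
    return (e.op == "match") + sum(count_match_nodes(a, seen) for a in e.args)


# The else branch, simplified: one match? node (%1) used by %1.0 and %1.1.
mm = Node("match", "p")
body = Node("add", Node("getitem", mm, 0), Node("getitem", mm, 1))

print(count_match_nodes(body))                 # 1
print(count_match_nodes(rebuild_naive(body)))  # 2, like %1 and %3 in the dump
print(count_match_nodes(rebuild_memo(body)))   # 1
```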

If I add an override like this:

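  // Visit every clause, but discard the visited clauses and return the
  // original Match node unchanged, so the match is never rebuilt.
  Expr VisitExpr_(const MatchNode* m) {
    std::vector<Clause> clauses;
    for (const Clause& p : m->clauses) {
      clauses.push_back(VisitClause(p));
    }
    return GetRef<Expr>(m);  // `clauses` is not used
  }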

to the PrimitiveInliner class (in https://github.com/apache/incubator-tvm/blob/master/src/relay/backend/vm/inline_primitives.cc#L54), compilation succeeds, at least for my limited example. However, this may have other consequences I am not aware of. I suspect it means that no transformations are applied to the bodies of match clauses, which could be bad.
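
The trade-off suspected above can be illustrated with another toy in plain Python. `Node`, `ToyInliner` and `contains` are hypothetical, and the 'prim' to 'inlined' rewrite merely stands in for whatever the real pass does. Returning the original match node keeps a single copy, so `%v` stays bound once. However, any rewrite computed inside a clause body is thrown away.

```python
class Node:
    """Toy IR node; args may be Nodes or plain leaves (names, indices)."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args


class ToyInliner:
    """Hypothetical memoizing mutator that rewrites every 'prim' node to 'inlined'."""

    def __init__(self, keep_original_match=False):
        # keep_original_match=True mimics the workaround override above.
        self.keep_original_match = keep_original_match
        self.memo = {}

    def visit(self, e):
        if not isinstance(e, Node):
            return e
        if id(e) not in self.memo:
            self.memo[id(e)] = self._rewrite(e)
        return self.memo[id(e)]

    def _rewrite(self, e):
        if e.op == "prim":
            return Node("inlined", *e.args)
        if e.op == "match" and self.keep_original_match:
            for a in e.args:  # visit the clauses...
                self.visit(a)
            return e          # ...but return the original, unrewritten node
        return Node(e.op, *(self.visit(a) for a in e.args))


def contains(e, op):
    """True if a node with the given op is reachable from e."""
    return isinstance(e, Node) and (e.op == op or any(contains(a, op) for a in e.args))


clause_body = Node("prim", "v")
mm = Node("match", "p", clause_body)
prog = Node("add", Node("getitem", mm, 0), Node("getitem", mm, 1))

rewritten = ToyInliner().visit(prog)
with_workaround = ToyInliner(keep_original_match=True).visit(prog)

print(contains(rewritten, "prim"))        # False
print(contains(with_workaround, "prim"))  # True: clause body left untouched
print(with_workaround.args[0].args[0] is with_workaround.args[1].args[0])  # True
```

In this toy, the match stays shared in both runs because `ToyInliner` memoizes. The workaround only changes whether the clause body gets rewritten.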

I'll keep working on this for a bit to find an acceptable solution, but I would gladly take any help or hints.

@abergeron abergeron changed the title [relay] ADT match becomes not well formed during VM optimization [relay] PrimitiveInliner doesn't correctly handle match expression that have more than one reference Jan 27, 2020
@abergeron abergeron changed the title [relay] PrimitiveInliner doesn't correctly handle match expression that have more than one reference [relay] PrimitiveInliner doesn't correctly handle match expressions that have more than one reference Jan 27, 2020
@abergeron abergeron changed the title [relay] PrimitiveInliner doesn't correctly handle match expressions that have more than one reference [relay] PrimitiveInliner doesn't correctly handle match expressions that have more than one use Jan 27, 2020
@abergeron
Contributor Author

Fixed by #4783
