Skip to content

Commit

Permalink
[Bugfix][TVMScript] Handle R.match_cast as last binding in if/else
Browse files Browse the repository at this point in the history
Prior to this commit, using `R.match_cast` as the last binding would
produce a segfault, as `var_binding->value` was used instead of
`match_cast->value`.  In addition, because the last binding of each
branch was removed, any changes to the struct info resulting from the
match cast were silently discarded.

This commit updates the TVMScript parsing of if/else statements to
remove the segfault and maintain the struct info changes produced by
the `R.match_cast`.
  • Loading branch information
Lunderberg committed Feb 14, 2024
1 parent bb2adbf commit af56392
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
4 changes: 3 additions & 1 deletion src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ void ElseFrameNode::ExitWithScope() {
IfFrame frame = FindIfFrame("R.Else");
frame->else_expr = output;
CHECK(frame->var_name == var_name)
<< "This last binding of both branches must have the same variable.";
<< "This last binding of both branches must provide the same variable. "
<< "However, the R.Then branch provides variable " << frame->var_name
<< ", while the R.Else branch provides variable " << var_name;
}

TVM_REGISTER_NODE_TYPE(FunctionFrameNode);
Expand Down
52 changes: 36 additions & 16 deletions src/script/ir_builder/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ inline BlockFrame CheckBlockFrameExistAndUnended() {
inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) {
// Step 0. Check frame type
std::string method;
std::string output_var_suffix;
if (frame->IsInstance<ThenFrameNode>()) {
method = "R.Then";
output_var_suffix = "_then";
} else if (frame->IsInstance<ElseFrameNode>()) {
method = "R.Else";
output_var_suffix = "_else";
} else {
ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey();
}
Expand All @@ -84,29 +87,46 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String
const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back();
CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty.";

// Step 2. Collect body from the last binding.
// Step 2. Update the last binding of each branch. While we could
// use the last bound value of each branch as a SeqExpr body, the
// Normalizer would pull it back out into a `gv#` binding anyways.
// Generating a new variable in each branch provides a more readable
// variable name.

tvm::relax::Binding last_binding = last_block->bindings.back();
CHECK(!last_binding->var->IsInstance<tvm::relax::DataflowVarNode>())
<< "A non-dataflow var is expected in the last binding of '" << method << "'.";

*var_name = last_binding->var->name_hint();

// Step 3. Re-collect binding blocks to replace the last binding.
Array<tvm::relax::BindingBlock> new_blocks(frame->binding_blocks.begin(),
frame->binding_blocks.end() - 1);
Array<tvm::relax::Binding> last_block_bindings(last_block->bindings.begin(),
last_block->bindings.end() - 1);

tvm::relax::Var new_var = tvm::relax::Var(last_binding->var->name_hint() + output_var_suffix,
GetStructInfo(last_binding->var));
tvm::relax::Expr body;
const tvm::relax::Binding& last_binding = last_block->bindings.back();
if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>()) {
CHECK(!var_binding->var->IsInstance<tvm::relax::DataflowVarNode>())
<< "A non-dataflow var is expected in the last binding of '" << method << "'.";

if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>();
var_binding && var_binding->value->IsInstance<tvm::relax::VarNode>()) {
body = var_binding->value;
*var_name = var_binding->var->name_hint();
} else if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>()) {
last_block_bindings.push_back(last_binding =
tvm::relax::VarBinding(new_var, var_binding->value));
body = new_var;
} else if (const auto* match_cast = last_binding.as<tvm::relax::MatchCastNode>()) {
CHECK(!match_cast->var->IsInstance<tvm::relax::DataflowVarNode>())
<< "A non-dataflow var is expected in the last binding of '" << method << "'.";
body = var_binding->value;
*var_name = match_cast->var->name_hint();
last_block_bindings.push_back(
tvm::relax::MatchCast(new_var, match_cast->value, match_cast->struct_info));
body = new_var;
} else {
ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey();
}

// Step 3. Re-collect binding blocks to remove the last binding.
Array<tvm::relax::BindingBlock> new_blocks(frame->binding_blocks.begin(),
frame->binding_blocks.end() - 1);
Array<tvm::relax::Binding> last_block_bindings(last_block->bindings.begin(),
last_block->bindings.end() - 1);
new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings));
new_blocks.push_back(last_block->IsInstance<tvm::relax::DataflowBlockNode>()
? tvm::relax::DataflowBlock(last_block_bindings)
: tvm::relax::BindingBlock(last_block_bindings));

return tvm::relax::SeqExpr(new_blocks, body);
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,47 @@ def check_call(call, op, args):
check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var])


def test_if_branch_with_match_cast():
"""The last branch of a relax::If node may be a MatchCast
This is a regression test. In previous implementations, using
R.match_cast as the last binding would cause a segfault while
parsing.
"""

@R.function
def func(A: R.Tensor([16, 16]), is_bfloat16: R.Prim("bool")):
if is_bfloat16:
A = R.match_cast(A, R.Tensor([16, 16], "bfloat16"))
B = A.astype("float16")
else:
B = R.match_cast(A, R.Tensor([16, 16], "float16"))
return B

A, is_bfloat16 = func.params
(block,) = func.body.blocks
(B_binding,) = block.bindings

B_var = B_binding.var
assert isinstance(B_var, relax.Var)
assert B_var.name_hint == "B"

if_then_else = B_binding.value
assert isinstance(if_then_else, relax.If)
assert isinstance(if_then_else.true_branch, relax.SeqExpr)
assert isinstance(if_then_else.false_branch, relax.SeqExpr)

else_branch = if_then_else.false_branch
(else_block,) = else_branch.blocks

assert isinstance(else_block.bindings[-1], relax.MatchCast)

# If the `R.match_cast` were removed, the function would infer the
# return value as `R.Tensor([16,16])`, with an unknown dtype.
# With the `R.match_cast` retained, the output dtype is known.
tvm.ir.assert_structural_equal(func.ret_struct_info, R.Tensor([16, 16], "float16"))


def test_if_inside_dataflow():
with pytest.raises(tvm.error.DiagnosticError):

Expand Down

0 comments on commit af56392

Please sign in to comment.