Skip to content

[Unity] Allow duplicated parameters in the same call arguments in FuseOps#14097

Closed
masahi wants to merge 5 commits intoapache:unityfrom
masahi:pattern-partition-dedup-issue
Closed

[Unity] Allow duplicated parameters in the same call arguments in FuseOps#14097
masahi wants to merge 5 commits intoapache:unityfrom
masahi:pattern-partition-dedup-issue

Conversation

@masahi
Copy link
Member

@masahi masahi commented Feb 23, 2023

Currently, when FuseOps and FuseOpsByPattern see a call node whose arguments have duplicated parameters, e.g.

with R.dataflow():
    out = R.add(data, data)
    R.output(out)

they create a grouped function whose signature has parameters deduplicated:

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        with R.dataflow():
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_add(data)
            R.output(gv)
        return gv

    @R.function
    def fused_relax_add(data: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Composite": "tensorrt.add", "Primitive": 1})
        with R.dataflow():
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(data, data)
            R.output(gv)
        return gv

This is fine if the grouped function is codegen-ed automatically by TVM, but for BYOC use cases (FuseOpsByPattern) this is problematic. If a user creates a pattern

add_pat = is_op("relax.add")(wildcard(), wildcard())

he / she expects to create a function with two parameters, that does addition. Indeed, if I replace the RHS with an expression other than data, such function is created. The fact that the same expression happens to be used for both LHS and RHS shouldn't matter when creating a grouped function. So the current behavior doesn't match user's intention. Moreover, a backend needs to able to handle multiple cases for code-generating the same add function with different signatures (one or two arguments).

This PR modifies the behavior of FuseOps and FuseOpsByPattern, so that duplicated parameters in the same call arguments are allowed to appear as distinct parameters in a grouped function. If the same expression is used in different bindings, like below, it is deduplicated (the current behavior is preserved).

with R.dataflow():
    lv = R.call_tir(add, (x, p0), out_sinfo=R.Tensor((1, 16, 64, 64), dtype="float32"))
    lv1 = R.call_tir(add, (x, p1), out_sinfo=R.Tensor((1, 16, 64, 64), dtype="float32"))
    lv2 = R.call_tir(add, (x, p2), out_sinfo=R.Tensor((1, 16, 64, 64), dtype="float32"))

@sunggg @Hzfengsy @vinx13 @yelite

@tvm-bot
Copy link
Collaborator

tvm-bot commented Feb 23, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@masahi masahi force-pushed the pattern-partition-dedup-issue branch from db6728d to 46b500c Compare February 23, 2023 07:41
@tqchen tqchen requested a review from Hzfengsy February 23, 2023 14:30
Copy link
Contributor

@sunggg sunggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the fix, @masahi!
Just one little nit :)

@masahi masahi force-pushed the pattern-partition-dedup-issue branch from 46b500c to 0959e57 Compare February 23, 2023 23:13
} else {
const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
ICHECK(tuple_item != nullptr);
CheckDefAndUpdateParam(tuple_item->tuple);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @Hzfengsy, I'm not sure what this code path is for. I can pass all tests without this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used for the TupleGetItem node. I'm surprised that it is not covered by the tests

// If the expression has already served as an argument, no need to create another one for it.
if (std::find(arguments_.begin(), arguments_.end(), expr) != arguments_.end()) {
return;
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @Hzfengsy, I decoupled deduplication check and the actual param update in this function, to allow duplicated expressions in arguments_ and params_. As the comment at L486 says, it is now a responsibility of the caller to do deduplication beforehand if dedup is needed.

See also the discussion in https://github.com/apache/tvm/pull/14097/files#r1116338447

@tqchen
Copy link
Member

tqchen commented Feb 24, 2023

Please run the following command to rebase the changes

git rebase --onto upstream/unity upstream/unity-rebase-backup-2023-02-24

@masahi masahi force-pushed the pattern-partition-dedup-issue branch from cc6e8c3 to 5eb4c85 Compare February 24, 2023 22:16
Copy link
Member

@Hzfengsy Hzfengsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @masahi for the PR. I acknowledge that duplicated parameters are useful in some cases (e.g. BYOC). However, it's a bit tricky if we only use tvm codegen, since:

  1. we lose part of the information at the callee side, which means it would be harder (but still possible) to optimize;
  2. Duplication may influence code readability.

Here, my question is: can we add a parameter for FuseOps pass to decide if we generate duplicated params? we can turn it off for tvm-codegen and turn on when BYOC

@vinx13
Copy link
Member

vinx13 commented Feb 25, 2023

I think a question we need to think about is whether we need duplicated args in BYOC cases. There are design tradeoffs here:

A1: Keep duplicated args so that the composite function signature is consistent with the pattern (add(wildcard(), wildcard())). In most cases this is easier for codegen. It can directly use arg0 and arg1 of the function as args for add. This is also the case for cutlass byoc, where we use arg0, arg1 for x and w of conv2d, respectively.

A2: Allow args being deduplicated. This approach indeed will generate duplicating composite function (with one arg or two args for add) for the codegen.

On the other hand, We also find some cases that the composite function doesn’t have args in the consistent argument order as the pattern / or ‘so-called’ anchor op. For example, we may have composite function like:

def fused_permute_matmul(y, x):
  lv = permute_dims(y)
  return matmul(x, lv)

where the arg order of matmul is different from the composite function.
In fact, the op fuser generally doesn’t guarantee the argument order of the fused function, its order is simply the order of the appearance of bindings in the original function. And as a result, for the above case we have to find the matmul, and then for each arg of matmul we need to find the corresponding function argument.

If the order of args in generally not assumed and we need to find such correspondence (anchor op args <-> function args) anyways, then whether the function args are deduplicated probably is less important.

@tqchen
Copy link
Member

tqchen commented Feb 25, 2023

In this case, if we can independently fix the BYOC(so that the pattern and generated func do not depend on ordering and duplication of the argument, that would leads to a more generalized solution which might help to resolve cases like @vinx13 mentioned

@yelite
Copy link
Contributor

yelite commented Feb 25, 2023

In #14128 I created a generalized way (https://github.com/apache/tvm/pull/14128/files#diff-5abe4ef979258cbcc4927f057f939304de762c63c445137ba570689854c8f862R553-R569) to map args of the fused function to named parameters of the offloaded operation (for example, map arg1 -> "lhs", arg0 -> "rhs", arg2 -> "bias" for matmul). This requires each pattern to be accompanied by a map from string to DFPattern, like {"lhs": pattern_lhs, "rhs": pattern_rhs}, where each value of the map is a child pattern of the pattern that's used to fuse function.

If we think the order of args and duplication of args is a common concern for all BYOC backends, but not for other use case of FuseOps, we should be able to extract that logic from #14128 and combine with the code in this PR to create a standalone pass for reordering and duplicating args.

@masahi
Copy link
Member Author

masahi commented Feb 27, 2023

In this case, if we can independently fix the BYOC(so that the pattern and generated func do not depend on ordering and duplication of the argument, that would leads to a more generalized solution which might help to resolve cases like @vinx13 mentioned

Thanks for the discussion, I think we shouldn't push all responsibilities to codegen and call it "a more general solution". At least for the cases where there is a one-to-one mapping between a composite function and an op, clearly there is one "right" function signature. So each backend shouldn't need to worry about add or other trivial functions possibly having multiple signatures. For more complicated, "truly fused" composite functions that are meant to be consumed by library-based BYOC (dnnl, cutlass), I think it is reasonable to ask codegen to handle the signature problem.

I liked @yelite suggestion of leaving the current op fusion impl as is, and instead adding a standalone pass that would identify such trivial composite functions (the one that merely wraps an op, e.g. tensorrt.add) and fix up their signature if necessary. Later we may extend that pass to support the fused_permute_matmul case - convert the signature into "more natural" one based on the anchor op call site.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants