
Implementation of Common Subexpression Elimination for TIR #9482

Merged
56 commits merged into apache:main on Feb 10, 2022

Conversation

FranckQC
Contributor

@FranckQC FranckQC commented Nov 10, 2021

Hi everyone,

We would like to upstream some work that we did at Qualcomm.
The goal of this PR is to implement a Common Subexpression Elimination (CSE) pass for TIR, which aims at identifying redundant computations (both within statements and within expressions) and at replacing them with a new variable, introduced before the first occurrence of the redundant computation.

Note that it does not only do commoning on full expressions; it is also able to do it on subexpressions. For instance, if the program computes the expression (w+x) + (y+z) and the expression (w+x) + u, it will introduce the subexpression (w+x) into a new variable.
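
For illustration, here is a minimal sketch of running the pass on exactly that example. It assumes the pass is exposed as tvm.tir.transform.CommonSubexprElimTIR; the exact TVMScript syntax may vary between TVM versions.

import tvm
from tvm.script import tir as T

@T.prim_func
def func(a: T.handle, w: T.int32, x: T.int32, y: T.int32, z: T.int32, u: T.int32) -> None:
    A = T.match_buffer(a, [2], dtype="int32")
    A[0] = (w + x) + (y + z)
    A[1] = (w + x) + u

mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.CommonSubexprElimTIR()(mod)
print(mod.script())  # (w + x) should now be bound to a single new variable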

Should we want to, it will be easy in the future to make the notion of equivalence between terms more flexible, allowing for instance to identify expressions modulo commutativity (identifying (x+y) with (y+x)), modulo associativity (identifying (x+y)+z with x+(y+z)), etc. Rewriting the single function bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) is the only thing needed to do that. The typical way to rewrite this function for such extensions would be to compute a canonical representative of a and a canonical representative of b, and to then compare them with strict syntactic equality.
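
As a toy illustration of that recipe (in Python, not the pass's actual C++ EquivalentTerms; canonicalize and equivalent_terms are made-up names), here is a version that only normalizes the commutativity of a top-level addition:

import tvm
from tvm import tir

def canonicalize(e):
    # Toy canonical form: order the two operands of a top-level Add by their
    # textual form, so that x + y and y + x get the same representative.
    if isinstance(e, tir.Add) and str(e.a) > str(e.b):
        return tir.Add(e.b, e.a)
    return e

def equivalent_terms(a, b):
    # Compare the canonical representatives with strict syntactic equality.
    return tvm.tir.analysis.expr_deep_equal(canonicalize(a), canonicalize(b))

x = tir.Var("x", "int32")
y = tir.Var("y", "int32")
print(equivalent_terms(x + y, y + x))  # True with this toy canonicalization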

The main CSE pass is declared and implemented respectively in the files common_subexpr_elim.h and common_subexpr_elim.cc.
The function Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) is a good entry point as it contains many comments about what the pass is doing.

The general idea of this pass is that it tries to introduce, at the current level (the current root), the redundant computations that are possible to introduce there (they should only contain variables that are in scope). This notion of variables in scope is implemented with a context, which is a vector of pairs (var, MaybeValue). The context is used not only for checking that the variables that appear in candidate computations are known at this point, but also for checking whether a computation has already been introduced into a variable.

For greater flexibility in the future, there is already a strong distinction in place between:

  • Syntactic computations, which are maintained in a hashtable which associates expressions (the computations already seen) to size_int (the number of times the computation has been seen).
  • Semantic entities, which are obtained from the syntactic computations by merging equivalent computations (where this notion of "equivalent" is customizable). Semantic entities are stored into a vector of pairs (expr, size_int) where, again, the number is the number of times that expr or equivalent computations have been seen.

The VisitStmt() method starts by computing the syntactic computations (implemented in an auxiliary analysis), then merges equivalent computations to obtain the semantic computations. It then sorts these semantic computations from biggest to smallest, in order to always consider the biggest computations first. The rest is essentially a loop over all these candidates, which stay sorted.
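
In simplified Python pseudocode (a sketch of the bookkeeping described above, not the actual C++ code; build_semantic_entities is a made-up name), the merge-and-sort step looks roughly like this:

def build_semantic_entities(syntactic_table, equivalent):
    """Merge equivalent syntactic computations and sort them by size.

    syntactic_table: dict mapping an expression to the number of times it has
                     been seen (the "syntactic computations" above).
    equivalent:      the customizable EquivalentTerms-like predicate.
    Returns a list of (expr, count) pairs sorted from biggest to smallest.
    """
    entities = []  # list of [representative_expr, count]
    for expr, n in syntactic_table.items():
        for entity in entities:
            if equivalent(entity[0], expr):
                entity[1] += n  # merge into an existing semantic entity
                break
        else:
            entities.append([expr, n])
    # Consider the biggest computations first; the size metric here is just
    # the length of the textual form, which is only an approximation.
    entities.sort(key=lambda p: len(str(p[0])), reverse=True)
    return [(expr, n) for expr, n in entities]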

When dealing with a candidate computation, three cases can happen (a pseudocode sketch follows the list):

  • 1 - Rare case where a variable in the context already contains this computation. This variable can't have been introduced by the CSE, as we would have performed the replacements at the same time (see case 2). So this is the case where the user (or a previous TIR pass) has written something like "let x = A in ...A...A..."
    -> In this case, we simply replace A with x in the current result. These replacements are done by an auxiliary transform/Mutator, declared and implemented in replace_expr_selected.h and replace_expr_selected.cc.

  • 2 - Case where we need to introduce the current computation inside a new variable. This is the case where all the variables used by the current computation are within scope (i.e. are present in the context) and where our internal heuristic/predicate tells us to introduce this computation into a new variable.
    -> In this case, a new variable new_var_i is generated, all the locations that use this computation in result are replaced by this fresh variable (using the same auxiliary Mutator mentioned in 1.), and the current result is replaced by "let new_var_i = currentComputation in result".

  • 3 - Case where we can't or don't want to introduce this computation inside a new variable. This is the case where we either can't introduce the current computation inside a new variable (because it contains variables that are not yet in scope there) or where our internal heuristic/predicate did not want to introduce it.
    -> In this case, we compute the direct sub-expressions of the current computation (implemented by an auxiliary analysis) and add them to the vector of semantic computations so that they have a chance to be considered later. Note that they are added while still preserving the order of the vector.
    Note that we do not add all the sub-expressions of the current expression but only its direct sub-expressions, given that we always consider candidates from biggest to smallest and that some candidates are mutually exclusive. Adding all of them would be computationally more intensive, and it would pose the problem of cleaning the vector of candidate computations when one of them gets introduced into a variable. Evaluating them lazily by only looking at the direct sub-expressions is both more efficient and simpler.
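
Putting the three cases together, one step of the candidate loop can be sketched as follows (illustrative Python pseudocode only; the helper names equivalent, vars_in_scope, should_introduce, fresh_var, replace_selected, make_let and direct_subexprs are hypothetical stand-ins for the auxiliary analyses and mutators mentioned above):

def handle_candidate(expr, count, context, result, helpers):
    # Case 1: some variable of the context is already bound to this
    # computation (written by the user or a previous pass): reuse it.
    for var, value in context:
        if value is not None and helpers.equivalent(value, expr):
            return helpers.replace_selected(result, expr, var), []

    # Case 2: every variable used by the computation is in scope and the
    # heuristic says it is worth it: bind it to a fresh variable with a let.
    if helpers.vars_in_scope(expr, context) and helpers.should_introduce(expr, count):
        new_var = helpers.fresh_var()
        body = helpers.replace_selected(result, expr, new_var)
        context.append((new_var, expr))
        return helpers.make_let(new_var, expr, body), []

    # Case 3: can't or don't want to bind it here: hand back its direct
    # sub-expressions so that they get considered later, in sorted order.
    return result, [(sub, count) for sub in helpers.direct_subexprs(expr)]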

Once the entire vector of semantic computations has been tried, the main function VisitStmt() calls the general dispatcher, which in turn calls the appropriate handlers. The only specific task of the overridden handlers is to update the context appropriately as new variables come into scope (via Let-In, For loops, etc.) or leave the current scope. Thus, they update the context before and after the calls to VisitStmt() and VisitExpr() on the child nodes.

For more details, this new pass was presented at TVMCon 2021 (https://www.tvmcon.org/events/common-subexpression-elimination-for-tir/), and thanks to the organizers of the conference the video is now available on YouTube: https://www.youtube.com/watch?v=Iuio-oJSOv0

Please do not hesitate to reach out if you have any questions.
Thank you.

@masahi
Member

masahi commented Nov 10, 2021

Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!

What I wanted to do is eliminate common expressions that span the host and the GPU - for example, in the GPU sort kernel, I need to make log2(N) GPU kernel calls from the host to sort the input bottom up. In principle, log2(N) needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the log2(N) expression that appears in both the host and the GPU kernel, right now the GPU sort kernel is littered with log2(N) computations like this (note the many calls to call_spirv_pure_glsl450, which are totally unnecessary if we had TIR-level CSE): https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f

Is this PR going to solve my problem?

Member

@Hzfengsy Hzfengsy left a comment


A quick round of review.

the changes to vta-hw are not necessary

src/tir/analysis/check_contains.cc (outdated review thread, resolved)
@FranckQC
Contributor Author

FranckQC commented Nov 15, 2021

Hi @FranckQC, I wanted TIR-level CSE for a long time, so very excited to see this!

What I wanted to do is eliminate common expressions that span the host and the GPU - for example, in the GPU sort kernel, I need to make log2(N) GPU kernel calls from the host to sort the input bottom up. In principle, log2(N) needs to be computed once by the host and passed to the GPU kernel, but since we cannot CSE the log2(N) expression that appears in both the host and the GPU kernel, right now the GPU sort kernel is littered with log2(N) computations like this (note the many calls to call_spirv_pure_glsl450, which are totally unnecessary if we had TIR-level CSE): https://gist.github.com/masahi/7a755ef67009e1a836e3212c53cf496f

Is this PR going to solve my problem?

Hi @masahi
Thanks a lot for the kind words! I'm happy to read that this new pass might be useful to you.
In principle, every redundant subterm that is eligible for being commoned out (i.e., one that does not contain function calls, etc.) will be commoned out. There are also a few other minor restrictions due to some specifics of TVM, but these are rare.

Unfortunately, in the TIR code snippet that you uploaded, the redundant subexpression is just a function call. These can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function is free of side effects and will always produce the same outputs for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it is not done, as preserving the semantics of the program is vital.

I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future. And it would rely on some "NoSideEffect" tag on functions.

However, please note that if you had some other redundancies, this CSE pass would common out whatever redundancies are eligible.
For instance:
Assume the term (f(42) + f(42)) + (x*y + x*y) appears somewhere.
It is not eligible in its entirety, as it contains function calls.
But its subterm (x*y) is eligible, so this subpart will be commoned out.
In short, the CSE pass, as implemented, always tries to common out all the redundant computations that are eligible, and it does so by looking from bigger subterms to smaller subterms.
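
For reference, the kind of eligibility check at play can be sketched in Python as follows (a hedged illustration, not the pass's actual predicate; is_eligible is a made-up name): reject any expression that contains a call node, since the pass cannot assume the callee is free of side effects.

import tvm
from tvm import tir

def is_eligible(expr):
    # Walk the expression and flag any tir.Call node found in it.
    found_call = [False]
    def visit(node):
        if isinstance(node, tir.Call):
            found_call[0] = True
    tvm.tir.stmt_functor.post_order_visit(expr, visit)
    return not found_call[0]

x = tir.Var("x", "int32")
y = tir.Var("y", "int32")
f = tir.Var("f", "float32")
print(is_eligible(x * y))        # True: plain arithmetic, can be commoned out
print(is_eligible(tir.log2(f)))  # False: contains a call, left untouched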

Does that answer your question?
Please do not hesitate to tell me if you need help trying the pass out.

Kind regards.

@FranckQC
Contributor Author

A quick round of review.

the changes to vta-hw are not necessary

Thanks a lot for the review!
I added the missing empty new lines at the end of each new file.

I'll roll back the pointer to the sub-module in another commit.

Many thanks again!

@FranckQC
Contributor Author

The pointer to the submodule vta-hw has been rolled back to its previous state.

@wrongtest-intellif
Contributor

Hi @FranckQC, I am also very interested in the CSE pass. In our circumstances we suffer from duplicate index computations produced by loop unrolling, like

A[i*256 + j*16 + 0] = B[i*256 + j*16 + 4096]
A[i*256 + j*16 + 1] = B[i*256 + j*16 + 4097]
...

We have to depend on the target backend's optimization abilities, which may or may not optimize them out. It would be great if CSE could handle part of these things at the TIR level.

@wrongtest-intellif
Contributor

wrongtest-intellif commented Nov 18, 2021

I have another three questions about the pass.

  1. Since function calls are non-eligible, it seems to prevent many common non-stateful builtin computations from being optimized, such as tir.floordiv, tir.shift_right, tir.likely, tir.log, etc. TVM already has op annotations like PURE:
    https://github.com/apache/tvm/blob/main/include/tvm/tir/op_attr_types.h#L60-L67
    What if we added an annotation type like DETERMINISTIC to mark that a function returns the same output for the same input? It should be safe to eliminate common calls if the call op is both pure and deterministic.

  2. Is the let-binding semantic in TIR explicitly what we want? Is the value expr specified to be evaluated at let-scope entry? Otherwise there may be no reuse effect if the let binding allows the value expr to be lazily expanded in the let body.

  3. Is it possible for some complex (but common) expr to be lifted out of a control-flow scope? e.g.

    if cond0:
        tir.evaluate(complex_e)  # slow path
    else:
        if cond2:
            tir.evaluate(complex_e)  # slow path
        else:
            ...  # general fast path

    after CSE, this might become

    x = tir.Var("x", "int32")
    with tir.let(x, complex_e):  # always evaluate slow path
        if cond0:
            tir.evaluate(x) 
        else:
            if cond2: 
                tir.evaluate(x)
            else:
                ... # general fast path

@wrongtest-intellif
Contributor

BTW, the current split-host-device machinery seems to work quite well with common expression bindings!
A test script is below; it shows that the common expression is evaluated in the host function and passed as a kernel param to the device function.

import tvm
from tvm.script import tir as T

@T.prim_func
def func(a: T.handle, b: T.handle, n: T.int32) -> None:
    threadIdx_x = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [256], dtype="int32")
    B = T.match_buffer(b, [256], dtype="int32")
    common_expr = T.var("int32")
    # for common_expr in range(n // 8, n // 8 + 1):
    with T.let(common_expr, n // 8):
        for i in T.serial(0, common_expr):
            T.launch_thread(threadIdx_x, 8)
            T.store(B.data, i * 8 + threadIdx_x, common_expr + T.load("int32", A.data, i * 8 + threadIdx_x), True)

mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "main", "target": tvm.target.Target("cuda")}))(mod)
mod = tvm.tir.transform.SplitHostDevice()(mod)
print(mod.script())


# script for result mod 
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(a: T.handle, b: T.handle, n: T.int32) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "target": None})
        A = T.match_buffer(a, [256], dtype="int32")
        B = T.match_buffer(b, [256], dtype="int32")
        # body
        for common_expr in T.serial(n // 8, n // 8 + 1):
            for i in T.serial(0, common_expr):
                T.evaluate(T.tvm_call_packed("main_kernel0", B.data, A.data, common_expr, i, 8, dtype="int32"))

    @T.prim_func
    def main_kernel0(B_1: T.Ptr[global T.int32], A_1: T.Ptr[global T.int32], common_expr: T.int32, i: T.int32) -> None:
        # function attr dict
        T.func_attr({"target": cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, "tir.noalias": 1, "global_symbol": "main_kernel0", "tir.device_thread_axis": [T.iter_var(threadIdx_x, [0:8], "ThreadIndex", "threadIdx.x")], "tir.is_global_func": 1, "calling_conv": 2})
        # var definition
        threadIdx_x = T.env_thread("threadIdx.x")
        # body
        T.launch_thread(threadIdx_x, 8)
        T.store(B_1, i * 8 + threadIdx_x, common_expr + T.load("int32", A_1, i * 8 + threadIdx_x), True)

@masahi
Member

masahi commented Nov 18, 2021

Does that answer your question?

@FranckQC Thanks, yes absolutely. I can work on extending this pass to support my use case, after this is merged.

@FranckQC
Contributor Author

FranckQC commented Nov 18, 2021

Hi @FranckQC, I am also very interested in the CSE pass. In our circumstances we suffer from duplicate index computations produced by loop unrolling, like

A[i*256 + j*16 + 0] = B[i*256 + j*16 + 4096]
A[i*256 + j*16 + 1] = B[i*256 + j*16 + 4097]
...

We have to depend on the target backend's optimization abilities, which may or may not optimize them out. It would be great if CSE could handle part of these things at the TIR level.

Hi @wrongtest
Thank you, I'm very glad to know that the pass might be useful to you too!

Yes, these kinds of redundancies should definitely be commoned out by this new CSE pass.
For the specific example that you have given, I expect the result to be something like:

Let cse_var_1 = i*256 + j*16 in
    A[cse_var_1 + 0] = B[cse_var_1 + 4096];
    A[cse_var_1 + 1] = B[cse_var_1 + 4097];
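
If it helps, a minimal way to check this on a toy version of your example could look like the sketch below (assuming the pass is exposed as tvm.tir.transform.CommonSubexprElimTIR; the exact TVMScript syntax and the generated variable names may differ):

import tvm
from tvm.script import tir as T

@T.prim_func
def unrolled(a: T.handle, b: T.handle, i: T.int32, j: T.int32) -> None:
    A = T.match_buffer(a, [8192], dtype="int32")
    B = T.match_buffer(b, [8192], dtype="int32")
    A[i * 256 + j * 16 + 0] = B[i * 256 + j * 16 + 4096]
    A[i * 256 + j * 16 + 1] = B[i * 256 + j * 16 + 4097]

mod = tvm.IRModule.from_expr(unrolled)
mod = tvm.tir.transform.CommonSubexprElimTIR()(mod)
print(mod.script())  # i*256 + j*16 should now be bound to a single variable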

Do not hesitate to try the pass out, and to let me know if it does what we hope. I'd be happy to help of course if that's needed.

Best regards.

@FranckQC
Contributor Author

FranckQC commented Feb 9, 2022

Current status:
[screenshot of CI status]

The only remaining failures are the VTA ones (both for python3: i386 and for unittest: CPU), which should now work with the trick kindly given by @masahi.
I did not realize that vta.build() actually overwrites the disabled passes of the config. I'll try to fix that just after, as it's a nasty side effect. But for now, let's verify that everything is green.
So, restarting the CI...

…{} for the disabled passes instead of a list with [] for some reason
@masahi
Member

masahi commented Feb 9, 2022

Please have a look again @Hzfengsy @wrongtest

I'm merging this this week unless there are other comments @tqchen @junrushao1994 @vinx13

Contributor

@mbs-octoml mbs-octoml left a comment


Thanks so much for the clear and complete comments, very much appreciate that.

Echo @Hzfengsy's suggestion to switch to TVMScript for your tests -- perhaps just pretty-printing what you already construct would be a shortcut.

The code duplication across PrimExpr & Stmt in CommonSubexpressionEliminator is unfortunate, but I suspect removing that would only obscure things.

In your experience is this mostly firing on the affine index sub-expressions, or do you see cse over actual data sub-expressions?

Member

@jroesch jroesch left a comment


Overall looks good to me and we should merge as is, @masahi do you mind adding some follow up issues to track the TODOs generated during review?

@masahi
Member

masahi commented Feb 9, 2022

Overall looks good to me and we should merge as is, @masahi do you mind adding some follow up issues to track the TODOs generated during review?

Sure! In addition to #9482 (comment), I think we can refactor our Substitute function to be a special case of the substituter with a predicate added in this PR.

@FranckQC
Contributor Author

FranckQC commented Feb 9, 2022

Overall looks good to me and we should merge as is, @masahi do you mind adding some follow up issues to track the TODOs generated during review?

Sure! In addition to #9482 (comment), I think we can refactor our Substitute function to be a special case of the substituter with a predicate added in this PR.

I would be very happy to look into that as soon as this is merged. Hopefully the current run of the CI should be the last one!

@FranckQC
Contributor Author

FranckQC commented Feb 9, 2022

Thanks so much for the clear and complete comments, very much appreciate that.

Echo @Hzfengsy's suggestion to switch to TVMScript for your tests -- perhaps just pretty-printing what you already construct would be a shortcut.

The code duplication across PrimExpr & Stmt in CommonSubexpressionEliminator is unfortunate, but I suspect removing that would only obscure things.

In your experience is this mostly firing on the affine index sub-expressions, or do you see cse over actual data sub-expressions?

Thank you so much for the compliment, I really appreciate it. It makes me happy to know that the code is easy to read!

If I recall correctly, I saw quite a lot of indices (mostly from loop unrolling), just like what @wrongtest had here #9482 (comment).

Also some indices due to lowering of memory accesses, for instance:
C[x,y] = C[x,y] + A[x,k] * B[y,k]
which can be lowered (2D to 1D) to:
C[x*128+y] = C[x*128+y] + A[x*128+k]*B[y*128+k]
which gives the opportunity to create first:
cse_var_1 = x*128+y
and then, in cascade:
cse_var_2 = x*128

And I also recall a lot of random commoning, like:
[screenshot: CSE results on a real program]

I'll post more if I can find more notes about more interesting commonings performed in test files and models.

@FranckQC
Contributor Author

FranckQC commented Feb 10, 2022

Sorry, I forgot to answer the other parts of your message, @mbs-octoml. Many thanks for it, by the way!

Thanks so much for the clear and complete comments, very much appreciate that.
Echo @Hzfengsy's suggestion to switch to TVMScript for your tests -- perhaps just pretty-printing what you already construct would be a shortcut.

Yes, I didn't know about TVMScript before. When writing the tests, I initially looked at some other test files and got inspiration from them. Unfortunately, the ones I was looking at might not have been the most up-to-date way of doing things, sorry for that! :-(
I guess we can still improve the tests later on, but they still accomplish their main function for now, which is the most important thing, even though they are probably a bit more verbose than they would be with TVMScript, where I could directly write the expected TIR code instead of unfolding the TIR code that the CSE pass has produced.

The code duplication across PrimExpr & Stmt in CommonSubexpressionEliminator is unfortunate, but I suspect removing that would only obscure things.

Yes, I agree that this duplication is a little unfortunate. @masahi did point it out here. I was also a little annoyed with it at the beginning, so I tried to factor it out a few times, including an attempt described in my answer here. But all my attempts ended up with something much too complicated for what we would gain.

In fact, we just happen to want to do almost exactly the same treatment for an expression and for a statement from an algorithmic point of view, while from a data-type point of view quite a few things are still different. That's a pretty rare situation. In the end, I decided not to force things and to leave it like that.
And the positive point of having VisitStmt() and VisitExpr() not factorized by some weird magic is that we can still easily customize the algorithm between expressions and statements if at some point we need to (I can't think of any reason we would want that, but still! :) )

Many thanks again!

@masahi masahi dismissed Hzfengsy’s stale review February 10, 2022 12:46

Will be addressed in a follow up

@masahi masahi merged commit 09f7be2 into apache:main Feb 10, 2022
@masahi
Member

masahi commented Feb 10, 2022

Thanks @FranckQC @Hzfengsy @wrongtest @mbs-octoml @jroesch this is merged!!

@masahi
Member

masahi commented Feb 10, 2022

Follow-up items in #10211

ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* Initial implementation of Common Subexpression Elimination for TIR (apache#703)

* Added empty newline at the end of every new file

* Rolled-back the pointer to the submodule vta-hw

* Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution path (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been comptued at all. Added two additional tests for this too.

* Spelling and comment

* Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution path (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been comptued at all. Added two additional tests for this too.

* Revert "Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution path (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been comptued at all. Added two additional tests for this too."

This reverts commit c4138d9.

* Fixed reference used for no reason instead of normal variable.

* Added comment explaning why we do not need the union/intersection over N tables at the moment (because we would only use it for N=3)

* Did most of the changes suggested by upstream

* Continued to work on the remarks given on the public repo.

* Final remarks addressed, small formatting things, and fixing things reported by the linter

* Last linter fix.

* Fixing newline

* Adding newline missing.

* Minor commit for style fo conform with clang-format

* Removed trailing space at end of line

* And more minor style changes

* Fixing style of the python test files

* And one more for style in python tests!

* This linter is very annoying to force the style of indentation in a comment, in a test file. It makes it harder to read in this case! And that incitates people to not write comments

* Deactivate the CSE pass for the lowering tests as it would otherwise do some commoning, and improve the way the CSE recurse + test added for cascade commonings

* Fixing new lint offenses

* Removing debug statement

* Restore other test file to its previous state

* One more for the linter...

* Linter again, this time for the new test...

* again

* again...

* Deactivating the CSE pass for another lowering test as it does some commoning

* Disabling the CSE for the a test for GPU too

* Trying to fix a VTA test by disabling the CSE pass for it, as it probably does some commoning

* Complying with the linter

* Restarting the CI 1/2

* Restarting the CI 2/2

* Restarting CI 1/2

* Restarting CI 2/2

* Slightly reduce size of large pretty printer test, copied from apache@ae98f9e

* Trying to resolve the problems on the weird tests

* Linter.

* Restarting CI which has skipped the MacOS build for no reason 1/2

* Restarting CI which has skipped the MacOS build for no reason 2/2

* Commented buggy tests

* Linter...

* Restore the VTA tests, and use trick kindly given  by Masa to disable the CSE pass for the VTA tests, as vta.build() overwrittes the config

* New fix, which this time does not break the doc (VTA uses a set with {} for the disabled passes instead of a list with [] for some reason

* More VTA fixes

* vta tutorial fix

Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
Labels: status: need update