Fix has_done_something and fixpoint logic
#545
Because `result` was always recomputed as the join of `result` and `result_`, once `result_.has_done_something` was set to True in the first loop iteration, `result.has_done_something` could never become False again.
This matters for fixed-point iteration: the loop then never detects convergence and simply runs until it reaches the maximum number of iterations.
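A minimal sketch of this failure mode, assuming that joining two results ORs their flags. `RewriteResult`, `buggy_fixpoint`, and `make_pass` are hypothetical stand-ins written for illustration, not kirin's actual classes:

```python
from dataclasses import dataclass


@dataclass
class RewriteResult:
    """Hypothetical stand-in for a rewrite result; not kirin's actual class."""

    has_done_something: bool = False

    def join(self, other: "RewriteResult") -> "RewriteResult":
        # Assumption: joining ORs the flags, so one True makes every later join True.
        return RewriteResult(self.has_done_something or other.has_done_something)


def buggy_fixpoint(run_once, max_iter: int = 32) -> int:
    """Return how often run_once is called when the accumulated flag is tested."""
    result = RewriteResult()
    iterations = 0
    for _ in range(max_iter):
        result_ = run_once()
        iterations += 1
        result = result_.join(result)
        if not result.has_done_something:  # bug: tests the accumulated flag
            break
    return iterations


def make_pass(work_items: int):
    """Build a toy pass that reports a change only `work_items` times."""
    remaining = [work_items]

    def run_once() -> RewriteResult:
        if remaining[0] > 0:
            remaining[0] -= 1
            return RewriteResult(True)
        return RewriteResult(False)

    return run_once


print(buggy_fixpoint(make_pass(1)))  # 32: one real change, then 31 wasted iterations
print(buggy_fixpoint(make_pass(0)))  # 1: the loop only exits if nothing ever changed
```

In this toy model, a pass that converges after a single change still consumes the full iteration budget.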
Before this fix, the following code would produce the following IR:
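from typing import Any, Literal

from kirin import rewrite
from kirin.dialects import ilist
from kirin.prelude import basic

# Repeat the chained rewrites over the whole method until nothing changes.
rule = rewrite.Fixpoint(
    rewrite.Walk(
        rewrite.Chain(
            ilist.rewrite.HintLen(),
            rewrite.ConstantFold(),
            rewrite.DeadCodeElimination(),
        )
    )
)

@basic
def len_func(xs: ilist.IList[int, Literal[3]]):
    return len(xs)

@basic
def len_func3(xs: ilist.IList[int, Any]):
    return len(xs)

rule.rewrite(len_func.code)
len_func.print()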
func.func len_func(!py.IList[!py.int, Literal(3,int)]) -> !Any {
^0(%len_func_self, %xs):
│ %0 = py.constant.constant 3 : !py.int
│ %1 = py.constant.constant 3 : !py.int
│ %2 = py.constant.constant 3 : !py.int
│ %3 = py.constant.constant 3 : !py.int
│ %4 = py.constant.constant 3 : !py.int
│ %5 = py.constant.constant 3 : !py.int
│ %6 = py.constant.constant 3 : !py.int
│ %7 = py.constant.constant 3 : !py.int
│ %8 = py.constant.constant 3 : !py.int
│ %9 = py.constant.constant 3 : !py.int
│ %10 = py.constant.constant 3 : !py.int
│ %11 = py.constant.constant 3 : !py.int
│ %12 = py.constant.constant 3 : !py.int
│ %13 = py.constant.constant 3 : !py.int
│ %14 = py.constant.constant 3 : !py.int
│ %15 = py.constant.constant 3 : !py.int
│ %16 = py.constant.constant 3 : !py.int
│ %17 = py.constant.constant 3 : !py.int
│ %18 = py.constant.constant 3 : !py.int
│ %19 = py.constant.constant 3 : !py.int
│ %20 = py.constant.constant 3 : !py.int
│ %21 = py.constant.constant 3 : !py.int
│ %22 = py.constant.constant 3 : !py.int
│ %23 = py.constant.constant 3 : !py.int
│ %24 = py.constant.constant 3 : !py.int
│ %25 = py.constant.constant 3 : !py.int
│ %26 = py.constant.constant 3 : !py.int
│ %27 = py.constant.constant 3 : !py.int
│ %28 = py.constant.constant 3 : !py.int
│ %29 = py.constant.constant 3 : !py.int
│ %30 = py.constant.constant 3 : !py.int
│ %31 = py.len.len(value=%xs : !py.IList[!py.int, Literal(3,int)]) : !py.int
│ func.return %0
} // func.func len_func
If a pass always reported that it had done something, fixpoint loops would never terminate (or would run until they reach the maximum number of iterations).
for _ in range(max_iter):
    result_ = self.unsafe_run(mt)
    result = result_.join(result)
    if not result.has_done_something:
The `_` is important: the convergence check must read `result_.has_done_something`. If the pass does something in the first iteration, `result.has_done_something` stays true in all later iterations.
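A sketch of the corrected loop shape, under the assumption that the join ORs the flags. `Result`, `fixpoint`, and `countdown_pass` are hypothetical names for this illustration and not the actual `Pass.fixpoint` implementation:

```python
from dataclasses import dataclass


@dataclass
class Result:
    """Hypothetical result type; only the flag matters for this sketch."""

    has_done_something: bool = False

    def join(self, other: "Result") -> "Result":
        return Result(self.has_done_something or other.has_done_something)


def fixpoint(run_once, max_iter: int = 32):
    """Iterate until one iteration reports no change; return (result, iterations)."""
    result = Result()
    iterations = 0
    for _ in range(max_iter):
        result_ = run_once()
        iterations += 1
        result = result_.join(result)  # still accumulate the overall outcome
        if not result_.has_done_something:  # but test this iteration only
            break
    return result, iterations


def countdown_pass(work_items: int):
    """Toy pass that reports a change for the first `work_items` calls."""
    remaining = [work_items]

    def run_once() -> Result:
        if remaining[0] > 0:
            remaining[0] -= 1
            return Result(True)
        return Result(False)

    return run_once


result, iterations = fixpoint(countdown_pass(3))
print(iterations, result.has_done_something)  # 4 True
```

The accumulated `result` still tells the caller whether anything changed overall, while the per-iteration `result_` decides when to stop.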
This PR fixes several issues related to fixpoint rewrites:
1. A subtle bug in `Pass.fixpoint` led to rewrites being applied for
`max_iter` iterations, even if they had already converged.
2. Several compiler passes were incorrectly returning
`has_done_something=True`, even if they had not done anything, leading
to maxed out fixpoint loops.
3. The CSE rewrite would prematurely exit, leading to more fixpoint
iterations.
I checked that all tests in the kirin test suite lead to fixpoint
rewrites that converge before reaching `max_iter`. It's still possible
that there are more rewrite passes that need to be fixed. We should
consider raising an error or at least a warning if a fixpoint pass does
not converge, as this will immediately point us to any such issues in
the future.
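One way such a warning could look, as a rough sketch only. `run_to_fixpoint` and the `step` callable are hypothetical and do not correspond to an existing kirin API:

```python
import warnings


def run_to_fixpoint(step, max_iter: int = 32) -> int:
    """Call step() until it returns False; warn if max_iter is exhausted.

    `step` is a hypothetical callable returning True when it changed something.
    Returns the number of times step() was called.
    """
    for iteration in range(1, max_iter + 1):
        if not step():
            return iteration  # converged
    warnings.warn(
        f"fixpoint did not converge within {max_iter} iterations",
        RuntimeWarning,
        stacklevel=2,
    )
    return max_iter


outcomes = iter([True, False])
print(run_to_fixpoint(lambda: next(outcomes)))  # 2, no warning
run_to_fixpoint(lambda: True, max_iter=4)  # emits a RuntimeWarning
```

A warning keeps existing pipelines running while still surfacing passes that never report convergence; raising an exception instead would be the stricter variant mentioned above.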
**The default `max_iter` is 32. With this fix, many passes now
iterate only once or twice instead of 32 times, which gives up to a 32x
speedup. Some passes also contain a fixpoint within a fixpoint; these can
become up to roughly 1000x faster (32 × 32 = 1024 iterations in the worst case).**
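The nested worst case can be checked with a toy counter. This is a sketch under the assumption that the inner pass either always reports a change or reports honestly that nothing is left to do; `nested_fixpoint_runs` is a hypothetical helper, not kirin code:

```python
def nested_fixpoint_runs(always_reports_change: bool, max_iter: int = 32) -> int:
    """Count inner-pass executions for a fixpoint nested inside a fixpoint."""
    runs = 0

    def inner_pass() -> bool:
        nonlocal runs
        runs += 1
        # The program is already converged, so an honest pass returns False.
        return always_reports_change

    def fixpoint(step) -> bool:
        changed = False
        for _ in range(max_iter):
            if not step():
                break
            changed = True
        return changed

    fixpoint(lambda: fixpoint(inner_pass))
    return runs


print(nested_fixpoint_runs(True))   # 1024 = 32 * 32
print(nested_fixpoint_runs(False))  # 1
```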
I noticed these issues when running the following script
```python
@qasm2.extended  # O(n) scaling
def main_fun():
    qreg = qasm2.qreg(1)
    for _ in range(n):
        qasm2.h(qreg[0])

QASM2().emit(main_fun)  # ~O(n^1.5)
```
This script runs very slowly for large `n`. More importantly, when
recording the duration of compilation and QASM2 emission, I found
non-linear scaling of roughly O(n^1.5). A compiler pass should be linear
in the program size (or at worst O(n log n)).
<img width="331" height="250" alt="image"
src="https://github.com/user-attachments/assets/c914db82-e2c6-4821-8a0e-b00ab19b432c"
/>
With the fixes in this PR, the execution time of this script is
significantly faster and the computational complexity is back to the
expected O(n):
<img width="331" height="245" alt="image"
src="https://github.com/user-attachments/assets/35480c95-29b5-4117-b796-c4cc891d1cb0"
/>
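A scaling exponent like the ones quoted above can be estimated as the slope of a log-log fit of duration against `n`. The sketch below uses synthetic durations, not the measurements behind the plots, and `scaling_exponent` is a hypothetical helper:

```python
import math


def scaling_exponent(sizes, durations) -> float:
    """Least-squares slope of log(duration) against log(size)."""
    xs = [math.log(n) for n in sizes]
    ys = [math.log(t) for t in durations]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    covariance = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    variance = sum((x - mean_x) ** 2 for x in xs)
    return covariance / variance


# Synthetic durations for illustration only, not data from this PR.
sizes = [100, 200, 400, 800, 1600]
print(round(scaling_exponent(sizes, [n ** 1.5 for n in sizes]), 3))  # 1.5
print(round(scaling_exponent(sizes, [2.0 * n for n in sizes]), 3))   # 1.0
```

With real timings, the durations would come from something like `time.perf_counter()` around the compile and emit calls, ideally repeated a few times per size to reduce noise.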
These issues also affect the runtime of @cduck's flair-to-squin pipeline
- but I think there are additional slowdowns that I'm still looking
into.