-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rework QJIT class into distinct compilation stages #531
Conversation
de6fe29
to
2ce0add
Compare
Also switch to only returning the "canonical" jaxpr output without implicit result values.
This cache can replace the existing caching mechanism for static arguments, and also caches different function definitions based on PyTreeDefs of dynamic arguments. Different function definitions based on the ShapedArrays of dynamic is not (yet) provided, as that would incurr additional overhead for type promotion checks.
The compilation_pipelines module is renamed as the name is outdated. The QJIT class is rewritten into more easily reuseable portions and simplified where possible. The updated class makes use of the new compilation cache to access previous versions of compiled functions.
6139eae
to
1c555ba
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I love those changes, thanks! I have left some comments
728ae41
to
3805b29
Compare
5917cb4
to
42791fe
Compare
42791fe
to
b0a9338
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #531 +/- ##
==========================================
- Coverage 99.55% 99.42% -0.13%
==========================================
Files 52 52
Lines 8457 8510 +53
Branches 559 568 +9
==========================================
+ Hits 8419 8461 +42
- Misses 20 27 +7
- Partials 18 22 +4 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great clean up 💯
Also fixes wrong attribute access uncovered by test.
cf11747
to
414de95
Compare
This is part 2 of a refactor started in #529. The QJIT class is reworked into 5 distinct compilation stages: - pre-compilation (like autograph) - capture (jaxpr generation) - ir-generation (mlir generation) - compilation (llvm and binary code generation - cannot be split up since this happens in the compiler driver) - execution The class is also streamlined by using a new compilation cache to handle previously compiled functions and signature lookups. One point of contention might be the results produced by the split of the `trace_to_mlir` function, which have been simplified and need to be double checked against #520. EDIT: c71c322 should address this concern [sc-57014] closes #268 closes #520
This is part 2 of a refactor started in #529.
The QJIT class is reworked into 5 distinct compilation stages:
The class is also streamlined by using a new compilation cache to handle previously compiled functions and signature lookups.
One point of contention might be the results produced by the split of the
trace_to_mlir
function, which have been simplified and need to be double checked against #520. EDIT: c71c322 should address this concern[sc-57014]
closes #268
closes #520