Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework QJIT class into distinct compilation stages #531

Merged
merged 12 commits into from Feb 23, 2024
Merged

Conversation

dime10
Copy link
Collaborator

@dime10 dime10 commented Feb 21, 2024

This is part 2 of a refactor started in #529.

The QJIT class is reworked into 5 distinct compilation stages:

  • pre-compilation (like autograph)
  • capture (jaxpr generation)
  • ir-generation (mlir generation)
  • compilation (llvm and binary code generation - cannot be split up since this happens in the compiler driver)
  • execution

The class is also streamlined by using a new compilation cache to handle previously compiled functions and signature lookups.

One point of contention might be the results produced by the split of the trace_to_mlir function, which have been simplified and need to be double checked against #520. EDIT: c71c322 should address this concern

[sc-57014]

closes #268
closes #520

@dime10 dime10 added the frontend Pull requests that update the frontend label Feb 21, 2024
Base automatically changed from frontend-refactor-1 to main February 21, 2024 22:32
Also switch to only returning the "canonical" jaxpr output without
implicit result values.
This cache can replace the existing caching mechanism for static
arguments, and also caches different function definitions based on
PyTreeDefs of dynamic arguments.

Different function definitions based on the ShapedArrays of dynamic
is not (yet) provided, as that would incurr additional overhead for
type promotion checks.
The compilation_pipelines module is renamed as the name is outdated.
The QJIT class is rewritten into more easily reuseable portions and
simplified where possible. The updated class makes use of the new
compilation cache to access previous versions of compiled functions.
Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love those changes, thanks! I have left some comments

frontend/test/pytest/test_debug.py Show resolved Hide resolved
frontend/catalyst/compiled_functions.py Show resolved Hide resolved
frontend/catalyst/compiled_functions.py Show resolved Hide resolved
frontend/catalyst/compiled_functions.py Show resolved Hide resolved
frontend/catalyst/jit.py Show resolved Hide resolved
@dime10 dime10 added the ci:build-wheels Run the wheel building workflows on this Pull Request label Feb 23, 2024
Copy link

codecov bot commented Feb 23, 2024

Codecov Report

Attention: Patch coverage is 94.23077% with 12 lines in your changes are missing coverage. Please review.

Project coverage is 99.42%. Comparing base (47fd3fc) to head (ba30a81).

❗ Current head ba30a81 differs from pull request most recent head cf11747. Consider uploading reports for the commit cf11747 to get more accurate results

Files Patch % Lines
frontend/catalyst/debug.py 36.36% 6 Missing and 1 partial ⚠️
frontend/catalyst/jit.py 96.11% 1 Missing and 3 partials ⚠️
frontend/catalyst/jax_tracer.py 91.66% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #531      +/-   ##
==========================================
- Coverage   99.55%   99.42%   -0.13%     
==========================================
  Files          52       52              
  Lines        8457     8510      +53     
  Branches      559      568       +9     
==========================================
+ Hits         8419     8461      +42     
- Misses         20       27       +7     
- Partials       18       22       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great clean up 💯

frontend/catalyst/jax_extras/tracing.py Outdated Show resolved Hide resolved
Also fixes wrong attribute access uncovered by test.
@dime10 dime10 merged commit 44506f1 into main Feb 23, 2024
53 of 55 checks passed
@dime10 dime10 deleted the frontend-refactor-2 branch February 23, 2024 23:40
rauletorresc pushed a commit that referenced this pull request Feb 26, 2024
This is part 2 of a refactor started in #529. 

The QJIT class is reworked into 5 distinct compilation stages:
- pre-compilation (like autograph)
- capture (jaxpr generation)
- ir-generation (mlir generation)
- compilation (llvm and binary code generation - cannot be split up
since this happens in the compiler driver)
- execution

The class is also streamlined by using a new compilation cache to handle
previously compiled functions and signature lookups.

One point of contention might be the results produced by the split of
the `trace_to_mlir` function, which have been simplified and need to be
double checked against #520. EDIT:
c71c322
should address this concern

[sc-57014]

closes #268 
closes #520
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci:build-wheels Run the wheel building workflows on this Pull Request frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Store correct JAX representation in QJIT object QJIT state machine raises recompilation warning unexpectedly
3 participants