-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Implement an if-elif-else result type unification #333
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #333 +/- ##
=======================================
Coverage 99.62% 99.62%
=======================================
Files 42 42
Lines 7454 7480 +26
Branches 439 449 +10
=======================================
+ Hits 7426 7452 +26
Misses 14 14
Partials 14 14
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice, great work @grwlf 💯
frontend/catalyst/jax_tracer.py
Outdated
jaxprs (list of ClosedJaxpr): Source JAXPR expressions. The expression results must have | ||
matching pytree-shapes and numpy array shapes but dtypes might | ||
be different. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optional: Should we assert this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We call the assertions in the _promote_jaxpr_types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't these assertions:
assert len(types) > 0, "Expected one or more set of types"
assert all(len(t) == len(types[0]) for t in types), "Expected matching number of arguments"
where types = [j.out_avals for j in jaxprs]
,
only check that the number of results from each jaxpr is the same? That is it doesn't check whether the pytrees match, nor whether array shapes match, which the docstring states as a requirement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, correct! I will add those.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added the shape assertion, but the pytree one is not that easy. We don't need the Pytrees when we do quantum tracing so there is no access to them. Adding them just to be able to assert here doesn't look feasible. We do check it elsewhere though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But that means we don't need to mention pytrees in the comment so I have updated it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's important that the Pytrees match, otherwise we wouldn't know which Pytree to restore as the result of the conditional operation.
This would make for a good test case actually, where they match and where they don't.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh sorry, I misspelled. I meant "I do think it's important that the PyTrees match". I think we should add those test cases @grwlf
For example:
@cond(True)
def cond_fn():
return (1, 1)
@cond_fn.otherwise
def cond_fn():
return [2, 2]
should fail, because we don't know which PyTree to restore as the result of the operation, even though the number of arguments and dtypes match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed here #351
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Return values of conditional functions no longer need to be of exactly the same type. Catalyst would apply type promotion to branch return values if their types don't match.
[sc-41332]