Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Implement an if-elif-else result type unification #333

Merged
merged 20 commits into from
Nov 2, 2023

Conversation

sergei-mironov
Copy link
Contributor

@sergei-mironov sergei-mironov commented Oct 25, 2023

Return values of conditional functions no longer need to be of exactly the same type. Catalyst would apply type promotion to branch return values if their types don't match.

[sc-41332]

@sergei-mironov sergei-mironov marked this pull request as ready for review October 26, 2023 11:25
@codecov
Copy link

codecov bot commented Oct 30, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (5d818b7) 99.62% compared to head (6215cc7) 99.62%.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #333   +/-   ##
=======================================
  Coverage   99.62%   99.62%           
=======================================
  Files          42       42           
  Lines        7454     7480   +26     
  Branches      439      449   +10     
=======================================
+ Hits         7426     7452   +26     
  Misses         14       14           
  Partials       14       14           
Files Coverage Δ
frontend/catalyst/jax_tracer.py 99.60% <100.00%> (+0.04%) ⬆️
frontend/catalyst/pennylane_extensions.py 99.64% <100.00%> (-0.01%) ⬇️
frontend/catalyst/utils/jax_extras.py 100.00% <100.00%> (ø)

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, great work @grwlf 💯

doc/changelog.md Outdated Show resolved Hide resolved
frontend/catalyst/jax_tracer.py Outdated Show resolved Hide resolved
frontend/catalyst/jax_tracer.py Outdated Show resolved Hide resolved
frontend/catalyst/jax_tracer.py Outdated Show resolved Hide resolved
Comment on lines 154 to 156
jaxprs (list of ClosedJaxpr): Source JAXPR expressions. The expression results must have
matching pytree-shapes and numpy array shapes but dtypes might
be different.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Should we assert this?

Copy link
Contributor Author

@sergei-mironov sergei-mironov Nov 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We call the assertions in the _promote_jaxpr_types

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't these assertions:

assert len(types) > 0, "Expected one or more set of types"
assert all(len(t) == len(types[0]) for t in types), "Expected matching number of arguments"

where types = [j.out_avals for j in jaxprs],

only check that the number of results from each jaxpr is the same? That is it doesn't check whether the pytrees match, nor whether array shapes match, which the docstring states as a requirement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, correct! I will add those.

Copy link
Contributor Author

@sergei-mironov sergei-mironov Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the shape assertion, but the pytree one is not that easy. We don't need the Pytrees when we do quantum tracing so there is no access to them. Adding them just to be able to assert here doesn't look feasible. We do check it elsewhere though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that means we don't need to mention pytrees in the comment so I have updated it.

Copy link
Contributor

@dime10 dime10 Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's important that the Pytrees match, otherwise we wouldn't know which Pytree to restore as the result of the conditional operation.

This would make for a good test case actually, where they match and where they don't.

Copy link
Contributor

@dime10 dime10 Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, I misspelled. I meant "I do think it's important that the PyTrees match". I think we should add those test cases @grwlf

For example:

            @cond(True)
            def cond_fn():
                return (1, 1)

            @cond_fn.otherwise
            def cond_fn():
                return [2, 2]

should fail, because we don't know which PyTree to restore as the result of the operation, even though the number of arguments and dtypes match.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed here #351

frontend/catalyst/jax_tracer.py Outdated Show resolved Hide resolved
Sergei Mironov and others added 6 commits November 1, 2023 14:36
@sergei-mironov sergei-mironov merged commit b3787a3 into main Nov 2, 2023
21 checks passed
@sergei-mironov sergei-mironov deleted the if-elif-else-resutl-type-unification branch November 2, 2023 17:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants