[Frontend, MLIR] Support indexing of the dynamically shaped arrays #411
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #411   +/-   ##
=======================================
  Coverage   99.56%   99.56%
=======================================
  Files          43       43
  Lines        7643     7646    +3
  Branches      512      512
=======================================
+ Hits         7610     7613    +3
  Misses         17       17
  Partials       16       16
☔ View full report in Codecov by Sentry.
Hi @rmoyard, could you please review the MLIR-related parts of this PR?
Nice patch 👍
2cf433c to 75a6d4c
@josh146 I meant this PR when I asked about the CodeFactor problems. The 'ComplexMethod' checks lack diagnostics, so I suggest disabling them.
with Patcher((jax._src.interpreters.partial_eval, "get_aval", get_aval2)), ExitStack():
with Patcher(
    (jax._src.interpreters.partial_eval, "get_aval", get_aval2),
    (jax._src.lax.slicing, "gather_p", gather2_p),
@dime10 @grwlf We have an issue here: the patched primitive does not define a JVP rule, so it is not compatible with the grad or jacobian transformations. The original gather_p registers one via ad.defjvp(gather_p, _gather_jvp_rule, None); see its definition:
gather_p = standard_primitive(
    _gather_shape_rule, _gather_dtype_rule, 'gather',
    weak_type_rule=_argnum_weak_type(0))
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
pe.padding_rules[gather_p] = _gather_pad_rule
@rmoyard does this result in a user-facing bug?
Yes, you cannot use jax.grad inside qjit when a gather operation is created, for example by slicing.
Ah, so #305?
That issue is about Catalyst gradients. I think Romain is talking about JAX gradients, which are run in the frontend on the jaxpr, hence the primitives need gradient rules.
Oh, got it!
Yes, exactly what David said: qjit with jax.grad and slicing.
qjit(jax.grad(f))(x)
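For reference, here is a minimal sketch of the kind of function under discussion, written in plain JAX only (no Catalyst import, so no qjit). The function f, the index array, and the input values are illustrative assumptions, not taken from the PR. Integer-array indexing is one way to make JAX emit a gather operation, and the jaxpr is printed so that can be checked on the installed JAX version.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Indexing with an integer array (hypothetical example indices).
    idx = jnp.array([0, 2])
    return jnp.sum(x[idx] ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Inspect the traced program to see which primitives the indexing produced.
print(jax.make_jaxpr(f)(x))

# Plain JAX can differentiate this because its own gather primitive has a
# JVP rule and a transpose rule registered.
g = jax.grad(f)(x)
print(g)
```

The failing case described in this thread would be the same f wrapped as qjit(jax.grad(f))(x), where the patched primitive is bound instead of the original one.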
Just checking if we have a resolution on this particular comment thread :)
@grwlf How is the upstream PR looking for this patch?
In the meantime, can we just attach the original gradient rule to the patched primitive?
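A rough sketch of what attaching an existing gradient rule to a patched primitive could look like, using toy primitives rather than gather. Everything named here (make_double_primitive, double_original, double_patched, double_jvp) is hypothetical and only illustrates the mechanism. The sketch registers the rule through the ad.primitive_jvps table rather than ad.defjvp, so the rule has the (primals, tangents) signature instead of the per-argument signature that _gather_jvp_rule uses.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.interpreters import ad

try:
    from jax.extend.core import Primitive
except ImportError:  # older JAX releases
    from jax.core import Primitive


def make_double_primitive(name):
    # Toy primitive computing 2 * x (hypothetical stand-in for gather).
    p = Primitive(name)
    p.def_impl(lambda x: jnp.multiply(x, 2.0))
    p.def_abstract_eval(lambda aval: aval)  # output has the input's shape/dtype
    return p


original_p = make_double_primitive("double_original")
patched_p = make_double_primitive("double_patched")


def double_jvp(prim, primals, tangents):
    # JVP of y = 2 * x: the tangent is scaled by the same factor.
    (x,), (t,) = primals, tangents
    return prim.bind(x), 2.0 * t


# Only the original primitive gets a rule at first, as in the thread.
ad.primitive_jvps[original_p] = partial(double_jvp, original_p)


def grad_is_missing(prim):
    try:
        jax.grad(lambda x: prim.bind(x))(jnp.float32(1.0))
    except NotImplementedError:
        return True
    return False


missing_before = grad_is_missing(patched_p)

# Reuse the same rule function for the patched primitive.
ad.primitive_jvps[patched_p] = partial(double_jvp, patched_p)
grad_after = jax.grad(lambda x: patched_p.bind(x))(jnp.float32(1.0))
print(missing_before, grad_after)
```

Whether the real _gather_jvp_rule can be reused unchanged depends on whether it binds the original gather_p internally, which this toy example does not cover.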
In this PR we enable indexing for tensors with dynamic shapes. This PR is intended to be merged after #370.
[sc-47632]
The corresponding JAX PR proposes the fix upstream.