-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: grad #4
WIP: grad #4
Conversation
@@ -732,7 +732,7 @@ def _infeed_p() -> Quantity: | |||
|
|||
@register(lax.integer_pow_p) | |||
def _integer_pow_p(x: Quantity, *, y: Any) -> Quantity: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _integer_pow_p(x: Quantity, *, y: Any) -> Quantity: | |
def _integer_pow_p(x: Quantity, *, y: int) -> Quantity: |
def _mul_p_vq(x: Value, y: Quantity) -> Quantity: | ||
return Quantity(lax.mul(x, y.value), unit=y.unit) | ||
def _mul_p_vq(x: DenseArrayValue, y: Quantity) -> Quantity: | ||
return Quantity(lax.mul(x.array, y.value), unit=y.unit) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check that the .array
is necessary. I was getting an infinite loop, but didn't investigate whether this was the best solution.
src/jax_quantity/array_namespace.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This whole file will hopefully not be necessary if patrick-kidger/quax#5 is resolved successfully.
1f37169
to
f89a10d
Compare
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
Get Quantity working with
jax.grad
.