-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add JAX integration tests #1685
Changes from all commits
43b1fe0
b6e78c6
5e7d680
7f7a65e
653a546
92cb81d
eedeabb
c180bf2
04dfcf5
61a60ef
31eae7c
c9eab59
a8db871
b8e511d
4d50817
eb75c4c
d09a382
db973d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -281,6 +281,6 @@ def requires_grad(tensor, interface=None): | |
if interface == "jax": | ||
import jax | ||
|
||
return isinstance(tensor, jax.interpreters.ad.JVPTracer) | ||
return isinstance(tensor, jax.core.Tracer) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
raise ValueError(f"Argument {tensor} is an unknown object") |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -400,7 +400,7 @@ def test_dot_product_qnodes_tensor(self, qnodes, interface, tf_support, torch_su | |
coeffs = coeffs.numpy() | ||
|
||
expected = np.dot(qcval, coeffs) | ||
assert np.all(res == expected) | ||
assert np.allclose(res, expected) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For some reason, this test was failing for me on CI (but not locally) 🤔 |
||
|
||
def test_unknown_interface(self, monkeypatch): | ||
"""Test exception raised if the interface is unknown""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a great change to have 🥇
How come it's placed here, instead of into the JAX branch of
requires_grad
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it needs to be here, since the
not any
check can only be done here, it cannot be done inside therequires_grad
check (which only checks a single tensor at a time) 🤔I could be wrong though, let me know if you see a way around this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I think you're right. 🤔 At least nothing else comes to mind that we could use here.