Skip to content

Commit

Permalink
try jax.experimental.host_callback.call again (not working though)
Browse files Browse the repository at this point in the history
  • Loading branch information
antalszava committed Nov 16, 2022
1 parent e2d4cf7 commit 1c8f85d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
18 changes: 13 additions & 5 deletions pennylane/interfaces/jax_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ def execute_wrapper_jvp(primals, tangents):
new_tapes = [_copy_tape(t, a) for t, a in zip(tapes, params)]
if isinstance(gradient_fn, qml.gradients.gradient_transform):

def wrapper(params, dys):
def wrapper(params):
params, dys = params
multi_measurements = [len(tape.measurements) > 1 for tape in tapes]
with qml.tape.Unwrap(*new_tapes):
jvp_tapes, processing_fn = qml.gradients.batch_jvp(
Expand All @@ -445,13 +446,20 @@ def wrapper(params, dys):
return jacs

total_params = np.sum([len(p) for p in params])
shape_dtype_structs = jax.ShapeDtypeStruct((1,), dtype)
jvps = jax.pure_callback(wrapper, [shape_dtype_structs], params, tangents[0])
shape_dtype_structs = jax.ShapeDtypeStruct((), dtype)

from jax.experimental.host_callback import call
#jacs = call(jacs_wrapper, , paramsshapes)
#jvps = jax.pure_callback(wrapper, [shape_dtype_structs], params, tangents[0])

params = params, tangents[0]
print(params)
jvps = call(wrapper, params, result_shape=[shape_dtype_structs])

res1 = execute_wrapper(params)
res2 = [jvps[0][0]]
print(res1, res2)
res2 = [jvps[0]]

print(res1, res2)
return res1, res2

return execute_wrapper(params)
Expand Down
2 changes: 1 addition & 1 deletion tests/returntypes/test_jax_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def cost(a, device):
interface=interface,
)[0]

res = jax.grad(cost)(a, device=dev)
res = jax.jacrev(cost)(a, device=dev)

for args in spy.call_args_list:
assert args[1]["shifts"] == [(np.pi / 4,)] * 2
Expand Down

0 comments on commit 1c8f85d

Please sign in to comment.