Skip to content

Commit

Permalink
Improved batching rule for vjp solve
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk committed Mar 3, 2020
1 parent a98313d commit 341b878
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions jaxfenics/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jax.api import defjvp_all

import functools
import itertools

from .helpers import (
numpy_to_fenics,
Expand Down Expand Up @@ -132,13 +133,8 @@ def vjp_fun1_batch(vector_arg_values, batch_axes):
This must be a JAX-traceable function.
Since the vjp_fun1 primitive already operates pointwise on arbitrary
dimension tensors, to batch it we can use the primitive itself. This works as
long as both the inputs have the same dimensions and are batched along the
same axes. The result is batched along the axis that the inputs are batched.
Args:
vector_arg_values: a tuple of two arguments, each being a tensor of matching
vector_arg_values: a tuple of arguments, each being a tensor of matching
shape.
batch_axes: the axes that are being batched. See vmap documentation.
Returns:
Expand All @@ -149,13 +145,13 @@ def vjp_fun1_batch(vector_arg_values, batch_axes):
batch_axes[0] == 0
) # assert that batch axis is zero, need to rewrite for a general case?
# compute function row-by-row
res = np.asarray(
[
vjp_fun1(vector_arg_values[0][i])
for i in range(vector_arg_values[0].shape[0])
]
)
return [res[:, i] for i in range(len(args))], (batch_axes[0],) * len(args)
res = [
vjp_fun1(vector_arg_values[0][i])
for i in range(vector_arg_values[0].shape[0])
]
# transpose resulting list
res_T = list(itertools.zip_longest(*res))
return tuple(map(np.vstack, res_T)), (batch_axes[0],) * len(args)

jax.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

Expand Down

0 comments on commit 341b878

Please sign in to comment.