Add in_parts and out_parts optional arguments to `jax.xla_computation`. (google#3771)

This allows partitioned computations in `xla_computation`, like those produced by `sharded_jit`.
skye authored and NeilGirdhar committed Jul 24, 2020
1 parent a1397e6 commit 0163277
Showing 2 changed files with 50 additions and 2 deletions.
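
A minimal usage sketch of the new arguments, mirroring the first test added in tests/api_test.py below (assumes the JAX version at this commit, where PartitionSpec lives in jax.interpreters.sharded_jit; building the HLO needs no real devices):

import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters.sharded_jit import PartitionSpec as P

def f(x, y):
  return jnp.dot(x, y) + 1

# Abstract shapes are enough for tracing; no concrete arrays are required.
x = jax.ShapeDtypeStruct((8, 8), np.float32)
y = jax.ShapeDtypeStruct((8, 16), np.float32)

# Partition x across a 2x2 grid, replicate y, and shard the single output 4x1.
xla_comp = jax.xla_computation(f, in_parts=(P(2, 2), None),
                               out_parts=P(4, 1))(x, y)
print(xla_comp.as_hlo_text())  # contains sharding={devices=[2,2]0,1,2,3} etc.
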
21 changes: 19 additions & 2 deletions jax/api.py
@@ -226,6 +226,7 @@ def _jit_is_disabled():
def xla_computation(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
in_parts=None, out_parts=None,
backend: Optional[str] = None,
tuple_args: bool = False,
instantiate_const_outputs: bool = True) -> Callable:
@@ -240,6 +241,12 @@ def xla_computation(fun: Callable,
functions that involve parallel communication collectives, and it
specifies the axis name/size environment that would be set up by
applications of :py:func:`jax.pmap`. See the examples below.
in_parts: Optional, how each argument to ``fun`` should be partitioned or
replicated. This is used to specify partitioned XLA computations, see
``sharded_jit`` for more info.
out_parts: Optional, how each output of ``fun`` should be partitioned or
replicated. This is used to specify partitioned XLA computations, see
``sharded_jit`` for more info.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
@@ -348,6 +355,8 @@ def computation_maker(*args, **kwargs):
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
wrapped, _ = argnums_partial(wrapped, dyn_argnums, args)
jax_args, in_tree = tree_flatten((args, kwargs))
in_parts_flat = tuple(flatten_axes("xla_computation in_parts",
in_tree.children()[0], in_parts))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
avals = map(abstractify, jax_args)
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
@@ -356,13 +365,21 @@
stage_out=True)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
out_parts_flat = flatten_axes("xla_computation out_parts",
out_tree(), out_parts)
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
xla_consts = map(partial(xb.constant, c), consts)
- xla_args = xla._xla_callable_args(c, avals, tuple_args)
+ xla_args = xla._xla_callable_args(c, avals, tuple_args,
+                                   partitions=in_parts_flat)
outs = xla.jaxpr_subcomp(
c, jaxpr, backend, axis_env_, xla_consts,
extend_name_stack(wrap_name(fun_name, 'xla_computation')), *xla_args)
- return c.build(xc.ops.Tuple(c, outs))
+ build_out_tuple = partial(xc.ops.Tuple, c, outs)
+ if out_parts is not None:
+   out_tuple = xb.with_sharding(c, out_parts, build_out_tuple)
+ else:
+   out_tuple = build_out_tuple()
+ return c.build(out_tuple)
return computation_maker

def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
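The axis_env, in_parts, and out_parts parameters documented above can be combined. The sketch below mirrors the second test added in tests/api_test.py (same assumptions as the sketch above) and checks that a psum over the mapped axis shows up as an all-reduce alongside the sharding annotations:

import numpy as np
import jax
from jax import lax
import jax.numpy as jnp
from jax.interpreters.sharded_jit import PartitionSpec as P

def f(x, y):
  # psum over the named axis 'i' lowers to an all-reduce in the HLO.
  return jnp.dot(x, y), lax.psum(x, 'i')

x = jax.ShapeDtypeStruct((8, 8), np.float32)
y = jax.ShapeDtypeStruct((8, 16), np.float32)

xla_comp = jax.xla_computation(
    f,
    axis_env=[('i', 4)],               # axis name/size as if set up by pmap
    in_parts=(P(2, 2), None),          # shard x over a 2x2 grid, replicate y
    out_parts=(P(4, 1), None))(x, y)   # shard first output, replicate second

hlo = xla_comp.as_hlo_text()
# Same checks as the new test below.
assert 'all-reduce' in hlo
assert 'sharding={devices=[2,2]0,1,2,3}' in hlo
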
31 changes: 31 additions & 0 deletions tests/api_test.py
@@ -35,6 +35,7 @@
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax import tree_util
@@ -932,6 +933,36 @@ def f(x, y):
xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
self.assertIn('constant(3)', xla_comp.as_hlo_text())

def test_xla_computation_partitioned(self):
def f(x, y):
return jnp.dot(x, y) + 1

x = jax.ShapeDtypeStruct((8, 8), np.float32)
y = jax.ShapeDtypeStruct((8, 16), np.float32)
xla_comp = api.xla_computation(f, in_parts=(P(2, 2), None),
out_parts=P(4, 1))(x, y)
hlo_text = xla_comp.as_hlo_text()
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
self.assertIn('sharding={replicated}', hlo_text)
self.assertIn('sharding={devices=[4,1]0,1,2,3}', hlo_text)

def test_xla_computation_replicated_and_partitioned(self):
def f(x, y):
return jnp.dot(x, y), lax.psum(x, 'i')

x = jax.ShapeDtypeStruct((8, 8), np.float32)
y = jax.ShapeDtypeStruct((8, 16), np.float32)
axis_env = [('i', 4)]
xla_comp = api.xla_computation(f, axis_env=axis_env,
in_parts=(P(2, 2), None),
out_parts=(P(4, 1), None))(x, y)
hlo_text = xla_comp.as_hlo_text()
self.assertIn('all-reduce', hlo_text)
self.assertIn('replica_groups={{0,1,2,3}}', hlo_text)
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
self.assertIn('sharding={replicated}', hlo_text)
self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text)

def test_jit_device(self):
device = xb.devices()[-1]
x = api.jit(lambda x: x, device=device)(3.)
