In [64]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

from functools import wraps
from jax import core
from jax import lax
from jax._src.util import safe_map

In [65]:
def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))

foo
=====
invars: [a]
outvars: [b]
constvars: []
equation: [a, 1] add [b] {}

jaxpr: { lambda  ; a.
  let b = add a 1
  in (b,) }

bar
=====
invars: [a, b, c]
outvars: [g, c]
constvars: []
equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': None}
equation: [d, b] add [e] {}
equation: [1.0] broadcast_in_dim [f] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [e, f] add [g] {}

jaxpr: { lambda  ; a b c.
  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] a c
      e = add d b
      f = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(5,) ] 1.0
      g = add e f
  in (g, c) }


In [66]:
from collections import OrderedDict

def slice_closed_jaxpr(closed_jaxpr, start=None, end=None):
#     print("closed_jaxpr.consts:", closed_jaxpr.consts)
#     print("closed_jaxpr.jaxpr.constvars:", closed_jaxpr.jaxpr.constvars)
#     print("closed_jaxpr.jaxpr.invars:", closed_jaxpr.jaxpr.invars)
#     print("closed_jaxpr.jaxpr.outvars:", closed_jaxpr.jaxpr.outvars)
    invars = set(closed_jaxpr.jaxpr.invars)
    consts_dir = OrderedDict(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))
    
    pred_intermediate_vars = set()
    
    slice_consts_dir = OrderedDict()
    slice_invars = []
    slice_outvars = []
    slice_eqns = []
    slice_intermediate_vars = set()

    succ_intermediate_vars = set()
    
    start = start if start is not None else 0
    end = end if end is not None else len(closed_jaxpr.jaxpr.eqns)
    
    for index, eqn in enumerate(closed_jaxpr.jaxpr.eqns):
#         print(index, eqn, eqn.invars, eqn.outvars)
        if index < start:
            pred_intermediate_vars.update(eqn.outvars)
        elif start <= index < end:
            slice_eqns.append(eqn)
            for var in eqn.invars:
                if isinstance(var, core.Literal):
                    continue
                elif var in consts_dir:
                    if var not in slice_consts_dir:
                        slice_consts_dir[var] = consts_dir[var]
                elif (var in invars) or (var in pred_intermediate_vars):
                    if var not in slice_invars: # FIXME: this is O(n^2)
                        slice_invars.append(var)
                else:
                    assert var in slice_intermediate_vars
            slice_intermediate_vars.update(eqn.outvars)
        else:  # end <= index
            for var in eqn.invars:
                if isinstance(var, core.Literal):
                    continue
                elif (var in invars) or (var in pred_intermediate_vars):
                    if var not in slice_invars: # FIXME: this is O(n^2)
                        slice_invars.append(var)
                    if var not in slice_outvars: # FIXME: this is O(n^2)
                        slice_outvars.append(var)
                elif var in slice_intermediate_vars:
                    if var not in slice_outvars: # FIXME: this is O(n^2)
                        slice_outvars.append(var)                    
                else:
                    assert (var in consts_dir) or (var in succ_intermediate_vars)
            succ_intermediate_vars.update(eqn.outvars)

    for var in closed_jaxpr.jaxpr.outvars:
        if (var in invars) or (var in pred_intermediate_vars):
            if var not in slice_invars: # FIXME: this is O(n^2)
                slice_invars.append(var)
            if var not in slice_outvars: # FIXME: this is O(n^2)
                slice_outvars.append(var)
        elif var in slice_intermediate_vars:
            if var not in slice_outvars: # FIXME: this is O(n^2)
                slice_outvars.append(var)                    
        else:
            assert (var in consts_dir) or (var in succ_intermediate_vars)

#     print("pred_intermediate_vars", pred_intermediate_vars)
#     print("slice_consts_dir", slice_consts_dir)
#     print("slice_invars", slice_invars)
#     print("slice_outvars", slice_outvars)
#     print("slice_eqns", slice_eqns)
#     print("slice_intermediate_vars", slice_intermediate_vars)
#     print("succ_intermediate_vars", succ_intermediate_vars)
    slice_jaxpr = core.Jaxpr(slice_consts_dir.keys(), slice_invars, slice_outvars, slice_eqns)
    slice_closed_jaxpr = core.ClosedJaxpr(slice_jaxpr, slice_consts_dir.values())
    return slice_closed_jaxpr

In [67]:
def f(x, z):
    y = jnp.sin(jnp.ones_like(x))
    x = y * jnp.tanh(x)
    x = jnp.sin(x)
    x = jnp.cos(x)
    x = jnp.exp(x)
    return x, z
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5), jnp.ones(6))
closed_jaxpr

{ lambda  ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(5,) ] 1.0
      d = sin c
      e = tanh a
      f = mul d e
      g = sin f
      h = cos g
      i = exp h
  in (i, b) }

In [68]:
closed_jaxpr_slice1 = slice_closed_jaxpr(closed_jaxpr, start=0, end=4)
closed_jaxpr_slice1

{ lambda  ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(5,) ] 1.0
      d = sin c
      e = tanh a
      f = mul d e
  in (f, b) }

In [69]:
closed_jaxpr_slice2 = slice_closed_jaxpr(closed_jaxpr, start=4)
closed_jaxpr_slice2

{ lambda  ; f b.
  let g = sin f
      h = cos g
      i = exp h
  in (i, b) }

In [70]:
core.jaxpr_as_fun(closed_jaxpr)(jnp.ones(5), jnp.ones(6))

[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]

In [71]:
intermediate = core.jaxpr_as_fun(closed_jaxpr_slice1)(jnp.ones(5), jnp.ones(6))
print(intermediate)
core.jaxpr_as_fun(closed_jaxpr_slice2)(*intermediate)

[DeviceArray([0.6408594, 0.6408594, 0.6408594, 0.6408594, 0.6408594], dtype=float32), DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]


[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]

In [72]:
intermediate = jit(core.jaxpr_as_fun(closed_jaxpr_slice1))(jnp.ones(5), jnp.ones(6))
jit(core.jaxpr_as_fun(closed_jaxpr_slice2))(*intermediate)

[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]

In [73]:
# TODO: merge with Lianmin's code
# TODO: PyTree inputs
# Q: How about lax.cond & lax.while?
#    Ideally we should inline lax.cond & lax.while
# Q: How about backward?
# Q: How to slice a computation into different stages, given that jaxpr is actually a graph?
# Why JaxPR? Try XLA
# Forward & backward device assignment (very general)

In [74]:
@jax.jit
def matmul(w, x):
    return w @ x

def f(w, x):
    x = matmul(w, x)
    x = jnp.exp(x)
    return x

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))
closed_jaxpr

{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                                                      precision=None
                                                      preferred_element_type=None ] a b
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=matmul ] a b
      d = exp c
  in (d,) }

In [75]:
with jax.disable_jit():
    closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))
closed_jaxpr

{ lambda  ; a b.
  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] a b
      d = exp c
  in (d,) }

In [139]:
from jax import core
from jax.lib import xla_client
from jax.interpreters import xla, ad

pipeline_start_p = core.Primitive("pipeline_start")  # Create the primitive
pipeline_start_p.multiple_results = True
pipeline_end_p = core.Primitive("pipeline_end")  # Create the primitive
pipeline_end_p.multiple_results = True

def mark_pipeline_start(*args, name):
    return pipeline_start_p.bind(*args, name=name)

def mark_pipeline_end(*args, name):
    return pipeline_end_p.bind(*args, name=name)


def pipeline_impl(*args, name):
    if len(args) == 0:
        return (None, )
    else:
        return args

def pipeline_abstract_eval(*args, name):
    if len(args) == 0:
        return (core.abstract_unit, )
    else:
        return args

def pipeline_xla_translation(c, *args, name):
    if len(args) == 0:
        return xla_client.ops.Tuple(c, (xla_client.ops.Constant(c, np.float32(0.0)), ))
        # xla_client.ops.Constant(c, np.float32(0.0))
    else:
        return xla_client.ops.Tuple(c, args)

def pipeline_start_transpose(*args):
    print("pipeline_start_transpose", args)
    raise(NotImplementedError())
    
    
pipeline_start_p.def_impl(pipeline_impl)
pipeline_start_p.def_abstract_eval(pipeline_abstract_eval)
xla.backend_specific_translations['cpu'][pipeline_start_p] = pipeline_xla_translation
xla.backend_specific_translations['gpu'][pipeline_start_p] = pipeline_xla_translation
xla.backend_specific_translations['tpu'][pipeline_start_p] = pipeline_xla_translation
ad.primitive_transposes[pipeline_start_p] = pipeline_start_transpose

pipeline_end_p.def_impl(pipeline_impl)
pipeline_end_p.def_abstract_eval(pipeline_abstract_eval)
xla.backend_specific_translations['cpu'][pipeline_end_p] = pipeline_xla_translation
xla.backend_specific_translations['gpu'][pipeline_end_p] = pipeline_xla_translation
xla.backend_specific_translations['tpu'][pipeline_end_p] = pipeline_xla_translation
ad.primitive_transposes[pipeline_end_p] = pipeline_start_transpose


In [142]:
def f_original(w, x):
    x = matmul(w, x)
    x = jnp.exp(x)
    x = jnp.sum(x)
    y = 7 * x
    return x

def f(w, x):
    w, x = mark_pipeline_start(w, x, name="1")
    x = matmul(w, x)
    mark_pipeline_end(name="1")
    x, = mark_pipeline_start(x, name="2")
    x = jnp.exp(x)
    x = jnp.sum(x)
    y = 7 * x
    x, = mark_pipeline_end(x, name="2")
    return x
with jax.disable_jit():
    closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))
closed_jaxpr

{ lambda  ; a b.
  let c d = pipeline_start[ name=1 ] a b
      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] c d
      _ = pipeline_end[ name=1 ] 
      f = pipeline_start[ name=2 ] e
      g = exp f
      h = reduce_sum[ axes=(0,) ] g
      _ = mul h 7.0
      i = pipeline_end[ name=2 ] h
  in (i,) }

In [143]:
jax.jit(f)(jnp.ones((5, 5)), jnp.ones(5))

DeviceArray(742.0658, dtype=float32)

In [144]:
with jax.disable_jit():
    closed_jaxpr = jax.make_jaxpr(jax.grad(f, argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))
closed_jaxpr

NotImplementedError: Differentiation rule for 'pipeline_start' not implemented

In [81]:
(jax.grad(f, argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))

(DeviceArray([[148.41316, 148.41316, 148.41316, 148.41316, 148.41316],
              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],
              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],
              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],
              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316]],            dtype=float32),
 DeviceArray([742.0658, 742.0658, 742.0658, 742.0658, 742.0658], dtype=float32))

In [86]:
closed_jaxpr.jaxpr.eqns[9].source_info.frames

[current;/home/ubuntu/efs/jax/jax/_src/source_info_util.py:65,
 default_process_primitive;/home/ubuntu/efs/jax/jax/interpreters/partial_eval.py:151,
 process_primitive;/home/ubuntu/efs/jax/jax/interpreters/partial_eval.py:140,
 bind;/home/ubuntu/efs/jax/jax/core.py:258,
 mul;/home/ubuntu/efs/jax/jax/_src/lax/lax.py:350,
 <lambda>;/home/ubuntu/efs/jax/jax/_src/lax/lax.py:2265,
 <genexpr>;/home/ubuntu/efs/jax/jax/interpreters/ad.py:461,
 standard_jvp2;/home/ubuntu/efs/jax/jax/interpreters/ad.py:462,
 process_primitive;/home/ubuntu/efs/jax/jax/interpreters/ad.py:274,
 bind;/home/ubuntu/efs/jax/jax/core.py:258,
 exp;/home/ubuntu/efs/jax/jax/_src/lax/lax.py:187,
 <lambda>;/home/ubuntu/efs/jax/jax/_src/numpy/lax_numpy.py:405,
 f;<ipython-input-77-6e51657d2cc0>:13,
 call_wrapped;/home/ubuntu/efs/jax/jax/linear_util.py:166,
 trace_to_jaxpr;/home/ubuntu/efs/jax/jax/interpreters/partial_eval.py:498,
 linearize;/home/ubuntu/efs/jax/jax/interpreters/ad.py:101,
 vjp;/home/ubuntu/efs/jax/jax/interpr

In [130]:
xla_client.ops.Tuple?