In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

from functools import wraps
from jax import core
from jax import lax
from jax._src.util import safe_map

In [2]:
def examine_jaxpr(closed_jaxpr):
  """Print the pieces of a closed jaxpr, one per line.

  Output order: the input, output and constant variable lists, then one
  "equation:" line per equation (its inputs, primitive, outputs and params),
  a blank line, and finally the whole jaxpr.

  Args:
    closed_jaxpr: object exposing a `.jaxpr` attribute, e.g. the result of
      `jax.make_jaxpr(fn)(*args)`.

  Returns:
    None; everything is written to stdout.
  """
  inner = closed_jaxpr.jaxpr
  # The three variable lists share one "<label>: <value>" format.
  for label in ("invars", "outvars", "constvars"):
    print(label + ":", getattr(inner, label))
  for equation in inner.eqns:
    print("equation:", equation.invars, equation.primitive,
          equation.outvars, equation.params)
  print()
  print("jaxpr:", inner)

def foo(x):
  """Return `x` plus one (the smallest possible function to trace)."""
  incremented = x + 1
  return incremented
# Trace `foo` on a Python int and dump the jaxpr that tracing produces.
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

# Blank separator line before the next dump.
print()

def bar(w, b, x):
  """Compute `jnp.dot(w, x) + b + jnp.ones(5)` and pass `x` through.

  Returns:
    A pair `(result, x)`; the second element is the `x` argument itself.
  """
  projected = jnp.dot(w, x)
  shifted = projected + b
  return shifted + jnp.ones(5), x
# Trace `bar` with a (5, 10) weight, a length-5 bias and a length-10 input,
# then dump the resulting jaxpr.
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))

foo
=====
invars: [a]
outvars: [b]
constvars: []
equation: [a, 1] add [b] {}

jaxpr: { lambda  ; a.
  let b = add a 1
  in (b,) }

bar
=====
invars: [a, b, c]
outvars: [g, c]
constvars: []
equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': None}
equation: [d, b] add [e] {}
equation: [1.0] broadcast_in_dim [f] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [e, f] add [g] {}

jaxpr: { lambda  ; a b c.
  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] a c
      e = add d b
      f = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(5,) ] 1.0
      g = add e f
  in (g, c) }


In [3]:
from collections import OrderedDict

def slice_closed_jaxpr(closed_jaxpr, start=None, end=None):
    """Extract equations ``[start, end)`` of a closed jaxpr as a new closed jaxpr.

    The slice reads everything it needs through its own inputs/constants and
    exposes everything that later equations (or the original outputs) still
    need, so consecutive slices can be chained: the outputs of one slice are
    the inputs of the next one, in the same order.

    Args:
        closed_jaxpr: the ``core.ClosedJaxpr`` to slice.
        start: index of the first equation to keep; ``None`` means 0.
        end: index one past the last equation to keep; ``None`` means
            ``len(closed_jaxpr.jaxpr.eqns)``.

    Returns:
        A new ``core.ClosedJaxpr`` where
        * constvars/consts are the original constants read by kept equations;
        * invars are original inputs, or values produced before ``start``,
          that kept equations read or that must be passed through because
          something after the slice still needs them;
        * outvars are the passed-through values plus values produced inside
          the slice that something after the slice (or the original outputs)
          needs.
        Each list is ordered by first use and contains no duplicates.

    Raises:
        AssertionError: if a variable is read that is neither an input, a
            constant, nor defined by an equation seen so far.

    Note:
        Original outputs that are literals or constants are not emitted by any
        slice, and a variable returned twice is emitted once.
    """
    jaxpr = closed_jaxpr.jaxpr
    invars = set(jaxpr.invars)
    consts_dir = OrderedDict(zip(jaxpr.constvars, closed_jaxpr.consts))

    # Variables defined by equations before / inside / after the slice.
    pred_intermediate_vars = set()
    slice_intermediate_vars = set()
    succ_intermediate_vars = set()

    # OrderedDicts are used as ordered sets: O(1) membership while keeping
    # first-use order (setdefault never moves an existing key).
    slice_consts_dir = OrderedDict()
    slice_invars = OrderedDict()
    slice_outvars = OrderedDict()
    slice_eqns = []

    start = 0 if start is None else start
    end = len(jaxpr.eqns) if end is None else end

    def forward_to_successors(var):
        """Expose `var` as a slice output if code after the slice reads it."""
        if isinstance(var, core.Literal):
            # Literals carry their value inline, so there is nothing to
            # forward. This check must come before any set lookup because
            # JAX literals are not hashable.
            return
        if (var in invars) or (var in pred_intermediate_vars):
            # Produced before the slice: pass it straight through.
            slice_invars.setdefault(var, None)
            slice_outvars.setdefault(var, None)
        elif var in slice_intermediate_vars:
            slice_outvars.setdefault(var, None)
        else:
            assert (var in consts_dir) or (var in succ_intermediate_vars)

    for index, eqn in enumerate(jaxpr.eqns):
        if index < start:
            pred_intermediate_vars.update(eqn.outvars)
        elif index < end:
            slice_eqns.append(eqn)
            for var in eqn.invars:
                if isinstance(var, core.Literal):
                    continue
                if var in consts_dir:
                    slice_consts_dir.setdefault(var, consts_dir[var])
                elif (var in invars) or (var in pred_intermediate_vars):
                    slice_invars.setdefault(var, None)
                else:
                    assert var in slice_intermediate_vars
            slice_intermediate_vars.update(eqn.outvars)
        else:  # end <= index
            for var in eqn.invars:
                forward_to_successors(var)
            succ_intermediate_vars.update(eqn.outvars)

    # The original outputs are consumers too, exactly like later equations.
    for var in jaxpr.outvars:
        forward_to_successors(var)

    slice_jaxpr = core.Jaxpr(list(slice_consts_dir.keys()), list(slice_invars),
                             list(slice_outvars), slice_eqns)
    return core.ClosedJaxpr(slice_jaxpr, list(slice_consts_dir.values()))

In [4]:
def f(x, z):
    """Chain of element-wise ops on `x`; `z` is returned untouched.

    Computes ``exp(cos(sin(sin(ones_like(x)) * tanh(x))))`` and returns it
    together with the `z` argument itself.
    """
    scale = jnp.sin(jnp.ones_like(x))
    squashed = scale * jnp.tanh(x)
    return jnp.exp(jnp.cos(jnp.sin(squashed))), z
# Trace `f` on a length-5 and a length-6 vector of ones; the bare name on the
# last line makes the notebook display the resulting jaxpr.
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5), jnp.ones(6))
closed_jaxpr

{ lambda  ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(5,) ] 1.0
      d = sin c
      e = tanh a
      f = mul d e
      g = sin f
      h = cos g
      i = exp h
  in (i, b) }

In [5]:
# First stage: equations 0..3 of the traced jaxpr (end index is exclusive).
closed_jaxpr_slice1 = slice_closed_jaxpr(closed_jaxpr, start=0, end=4)
closed_jaxpr_slice1

{ lambda  ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(5,) ] 1.0
      d = sin c
      e = tanh a
      f = mul d e
  in (f, b) }

In [6]:
# Second stage: equation 4 through the last equation (end defaults to the end).
closed_jaxpr_slice2 = slice_closed_jaxpr(closed_jaxpr, start=4)
closed_jaxpr_slice2

{ lambda  ; f b.
  let g = sin f
      h = cos g
      i = exp h
  in (i, b) }

In [7]:
# Reference result: evaluate the unsliced jaxpr on the same inputs used for tracing.
core.jaxpr_as_fun(closed_jaxpr)(jnp.ones(5), jnp.ones(6))

[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]

In [8]:
# Run the two stages back to back: the outputs of slice 1 are fed positionally
# into slice 2, and the final values should match the unsliced evaluation.
intermediate = core.jaxpr_as_fun(closed_jaxpr_slice1)(jnp.ones(5), jnp.ones(6))
print(intermediate)
core.jaxpr_as_fun(closed_jaxpr_slice2)(*intermediate)

[DeviceArray([0.6408594, 0.6408594, 0.6408594, 0.6408594, 0.6408594], dtype=float32), DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]


[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]

In [9]:
# Same two-stage evaluation, with each stage wrapped in `jit`.
intermediate = jit(core.jaxpr_as_fun(closed_jaxpr_slice1))(jnp.ones(5), jnp.ones(6))
jit(core.jaxpr_as_fun(closed_jaxpr_slice2))(*intermediate)

[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]

In [10]:
# TODO: merge with Lianmin's code
# TODO: support PyTree inputs
# Q: How should lax.cond & lax.while_loop be handled?
#    Ideally we should inline lax.cond & lax.while_loop
# Q: How should the backward pass be handled?
# Q: How to slice a computation into different stages, given that a jaxpr is actually a graph?
# Q: Why jaxpr? Try XLA instead.
# TODO: forward & backward device assignment (very general)

In [11]:
def f(w, x):
    """Return the element-wise exponential of the product `w @ x`."""
    return jnp.exp(w @ x)
# Trace the matmul+exp `f` on a (5, 5) matrix and a length-5 vector and display
# the jaxpr. NOTE: this rebinds `closed_jaxpr` from the earlier slicing cells.
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))
closed_jaxpr

{ lambda  ; a b.
  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] a b
      d = exp c
  in (d,) }

In [12]:
# Lower `f` to an XLA computation for the same argument shapes and print its HLO text.
c = jax.xla_computation(f)(jnp.ones((5, 5)), jnp.ones(5))
print(c.as_hlo_text())

HloModule xla_computation_f.7

ENTRY xla_computation_f.7 {
  constant.3 = pred[] constant(false)
  parameter.1 = f32[5,5]{1,0} parameter(0)
  parameter.2 = f32[5]{0} parameter(1)
  dot.4 = f32[5]{0} dot(parameter.1, parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  exponential.5 = f32[5]{0} exponential(dot.4)
  ROOT tuple.6 = (f32[5]{0}) tuple(exponential.5)
}




In [17]:
from jax import core
from jax.lib import xla_client

# Marker primitive bound by mark_pipeline_start(); its eager impl, abstract-eval
# rule and per-backend XLA translation are registered further down in this cell.
pipeline_start_p = core.Primitive("pipeline_start")  # Create the primitive

def mark_pipeline_start(name):
    """Bind the `pipeline_start` primitive, tagging it with `name`.

    `name` is forwarded as the primitive's `name` parameter, and whatever
    `bind` returns is handed back unchanged.
    """
    marker = pipeline_start_p.bind(name=name)
    return marker

def pipeline_start_impl(name):
    """Eager implementation of `pipeline_start`: does nothing, returns None.

    NOTE(review): the abstract-eval rule reports `core.abstract_unit`, while
    this returns plain `None` -- confirm that is intended for eager calls.
    """
    del name  # The marker has no eager behaviour of its own.
    return None

def pipeline_start_abstract_eval(name):
    """Abstract-eval rule for `pipeline_start`: always `core.abstract_unit`.

    The `name` parameter does not influence the result.
    """
    unit_aval = core.abstract_unit
    return unit_aval

def pipeline_start_xla_translation(c, name):
    """XLA translation rule for `pipeline_start`: emit a float32 scalar 0.0.

    Args:
        c: the builder passed to `xla_client.ops.Constant`.
        name: the marker's name; not used by the lowering.

    Returns:
        The op produced by `xla_client.ops.Constant`.

    NOTE(review): the abstract-eval rule declares a unit output while this
    emits an f32 scalar -- confirm the mismatch is harmless for callers.
    """
    placeholder = np.float32(0.0)
    return xla_client.ops.Constant(c, placeholder)

from jax.interpreters import xla

# Attach the eager implementation and the abstract-eval rule to the primitive.
pipeline_start_p.def_impl(pipeline_start_impl)
pipeline_start_p.def_abstract_eval(pipeline_start_abstract_eval)

# Every backend uses the same translation rule.
for _backend in ('cpu', 'gpu', 'tpu'):
    xla.backend_specific_translations[_backend][pipeline_start_p] = pipeline_start_xla_translation


In [18]:

def f(w, x):
    """Matmul followed by exp, preceded by a `pipeline_start` marker named "1".

    The marker's return value is discarded; the result is `exp(w @ x)`.
    """
    mark_pipeline_start("1")
    return jnp.exp(w @ x)
# Trace the marked `f` and display the jaxpr so the `pipeline_start` equation is visible.
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))
closed_jaxpr

{ lambda  ; a b.
  let _ = pipeline_start[ name=1 ] 
      c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] a b
      d = exp c
  in (d,) }

In [15]:
# Compile and run `f` end to end.
# NOTE(review): this cell's recorded execution count (15) is lower than the
# cells above it (17, 18), so the stored output may predate the current
# definition of `f` -- restart the kernel and run all cells to confirm.
jax.jit(f)(jnp.ones((5, 5)), jnp.ones(5))

DeviceArray([148.41316, 148.41316, 148.41316, 148.41316, 148.41316], dtype=float32)

In [16]:
# IPython-only `?` introspection left over from development; it is not valid
# plain Python -- remove this cell before sharing the notebook.
xla_client.ops.Constant?