# Using existing primitives

In [1]:
import jax
import traceback

jax.config.update("jax_platform_name", "cpu")

In [2]:
def mul_add_lax(x, y, z):
    """Return ``x * y + z`` composed from the existing ``jax.lax`` operations."""
    product = jax.lax.mul(x, y)
    return jax.lax.add(product, z)


def square_add_lax(x, y):
    """Return ``x * x + y`` by passing ``x`` to ``mul_add_lax`` as both factors."""
    squared_plus_y = mul_add_lax(x, x, y)
    return squared_plus_y


square_add_lax(2, 10)

Array(14, dtype=int32, weak_type=True)

In [3]:
jax.grad(square_add_lax)(2.0, 10.0)

Array(4., dtype=float32, weak_type=True)

## Copy Paste

In [4]:
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
    """Print ``msg`` prefixed by two spaces per current indentation level.

    Nothing is printed when ``msg`` is None.
    """
    if msg is None:
        return
    prefix = "  " * _indentation
    print(prefix + msg)

def _trace_indent(msg=None):
    """Print ``msg`` (if given) at the current level, then go one level deeper."""
    global _indentation
    _trace(msg)
    _indentation += 1

def _trace_unindent(msg=None):
    """Step one indentation level back out, then print ``msg`` (if given)."""
    global _indentation
    _indentation -= 1
    _trace(msg)

def trace(name):
  """A decorator factory for tracing a function's arguments and result.

  Args:
    name: label used in the printed "call ..." and "|<- ..." lines.
  Returns:
    a decorator. The decorated function prints its arguments on entry and
    indents subsequent trace output; on return it unindents and prints the
    result. Keyword arguments are forwarded to the wrapped function and
    shown as ``key=value`` after the positional arguments.
  """

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
        """Print certain values more succinctly"""
        vtype = str(type(v))
        if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
            return "<JaxComputationBuilder>"
        elif "jaxlib.xla_extension.XlaOp" in vtype:
            return "<XlaOp at 0x{:x}>".format(id(v))
        elif ("partial_eval.JaxprTracer" in vtype or
              "batching.BatchTracer" in vtype or
              "ad.JVPTracer" in vtype):
            return "Traced<{}>".format(v.aval)
        elif isinstance(v, tuple):
            return "({})".format(pp_values(v))
        else:
            return str(v)

    def pp_values(args):
        """Render a sequence of values as a comma-separated string."""
        return ", ".join([pp(arg) for arg in args])

    def pp_call(args, kwargs):
        """Render positional then keyword arguments the way a call is written."""
        rendered = [pp(arg) for arg in args]
        rendered.extend("{}={}".format(key, pp(val)) for key, val in kwargs.items())
        return ", ".join(rendered)

    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
      _trace_indent("call {}({})".format(name, pp_call(args, kwargs)))
      try:
        res = func(*args, **kwargs)
      except BaseException:
        # Restore the indentation level (printing nothing) so a failing call
        # does not skew the indentation of later trace output.
        _trace_unindent()
        raise
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager asserting that its body raises NotImplementedError.

  On exit the module-level trace indentation is reset to 0. If the body
  raised NotImplementedError (or a subclass), a short traceback is printed
  and the exception is suppressed. If the body raised nothing, an
  AssertionError is raised. Any other exception propagates unchanged.
  """
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    # Start the next traced call from a clean indentation level.
    _indentation = 0
    if type is None:  # No exception
      # Raised explicitly: an `assert` statement is stripped under `python -O`.
      raise AssertionError("Expected NotImplementedError")
    if issubclass(type, NotImplementedError):
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True  # suppress the expected exception
    return False  # let unrelated exceptions propagate

In [5]:
from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

# @trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """JAX-traceable entry point for the `multiply_add` primitive.

  The arguments are handed to `bind` positionally, which is how traced
  arguments must be passed.
  """
  bound_result = multiply_add_p.bind(x, y, z)
  return bound_result

# @trace("square_add_prim")
def square_add_prim(a, b):
  """Square-add built on the new JAX primitive: `a` serves as both factors."""
  squared_plus_b = multiply_add_prim(a, a, b)
  return squared_plus_b

In [6]:
with expectNotImplementedError():
  square_add_prim(2., 10.)


Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_30442/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_30442/1717318103.py", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_30442/1717318103.py", line 11, in multiply_add_prim
    return multiply_add_p.bind(x, y, z)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Evaluation rule for 'multiply_add' not implemented


In [7]:
import numpy as np
from jax import lax
from jax._src import api

# @trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete (eager) implementation of the primitive: `x * y + z`.

  This function does not have to be JAX traceable.
  Args:
    x, y, z: the concrete arguments of the primitive. Will only be called
      with concrete values.
  Returns:
    the concrete result of the primitive.
  """
  # Plain numpy is fine here even though it is not JAX traceable.
  product = np.multiply(x, y)
  return np.add(product, z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)

<function __main__.multiply_add_impl(x, y, z)>

In [8]:
assert square_add_prim(2., 10.) == 14.

In [9]:
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)


Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_30442/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/jax/_src/pjit.py", line 257, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
                                                 ^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented


In [10]:
from jax import core
# @trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not have to be JAX traceable. It is invoked with
  abstractions of the actual arguments.
  Args:
    xs, ys, zs: abstractions of the arguments; all three must share one shape.
  Result:
    a ShapedArray carrying the shape and dtype of `xs`.
  """
  expected_shape = xs.shape
  assert expected_shape == ys.shape
  assert expected_shape == zs.shape
  return core.ShapedArray(expected_shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

<function __main__.multiply_add_abstract_eval(xs, ys, zs)>

In [11]:
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)


Found expected exception:


Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_30442/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  

In [12]:
from jax._src.lib.mlir.dialects import hlo
# @trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """Lowering rule: the compilation of the primitive to XLA.

  Receives one mlir.ir.Value per argument and returns the mlir.ir.Values
  for the results, here a one-element list holding `xc * yc + zc`.
  `ctx` is accepted but not used.

  Does not have to be a JAX-traceable function.
  """
  product_op = hlo.MulOp(xc, yc)
  sum_op = hlo.AddOp(product_op, zc)
  return [sum_op.result]

# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')

<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>

In [13]:
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.

In [14]:
assert api.jit(lambda x, y: square_add_prim(x, y), 
               static_argnums=1)(2., 10.) == 14.

In [15]:
# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))


Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_30442/800067577.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/jax/_src/api.py", line 1945, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/jax/_src/api.py", line 1974, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Differentiation rule for 'multiply_add' not implemented


In [16]:
from jax.interpreters import ad


# @trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Computes the primal output and its tangent (Jacobian-vector product).

  From the argument values and their perturbations (tangents), produce the
  output of the primitive together with the perturbation of that output.

  This function must be JAX-traceable. JAX may invoke it with abstract
  values for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments, of the same
      length as arg_values. A tangent may be the special value ad.Zero,
      which stands for a zero tangent.
  Returns:
     a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents

  def make_zero(tan):
    # ad.Zero is materialised as an array of zeros shaped like 'x'. The
    # alternative would be to test for Zero and simplify the tangent
    # expression algebraically.
    if type(tan) is ad.Zero:
      return lax.zeros_like_array(x)
    return tan

  _trace("Primal evaluation:")
  # The multiply_add primitive itself gives a JAX-traceable primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # The output tangent is (xt * y + x * yt + zt). It has to be computed in a
  # JAX-traceable way, so it is expressed as two nested uses of the same
  # "multiply_add_prim" primitive.
  xt_dense = make_zero(xt)
  yt_dense = make_zero(yt)
  zt_dense = make_zero(zt)
  x_yt_plus_zt = multiply_add_prim(x, yt_dense, zt_dense)
  output_tangent = multiply_add_prim(xt_dense, y, x_yt_plus_zt)
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX 
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp

In [17]:
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)

Primal evaluation:
Tangent evaluation:


In [18]:
assert api.jit(lambda arg_values, arg_tangents: 
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)

Primal evaluation:
Tangent evaluation:


In [19]:
# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)

Primal evaluation:
Tangent evaluation:

Found expected exception:


Traceback (most recent call last):
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 287, in get_primitive_transpose
    return primitive_transposes[p]
           ~~~~~~~~~~~~~~~~~~~~^^^
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Trace

# Defining the new Primitive

## Create new primitive

In [20]:
multiply_add_p = jax.core.Primitive("multiply_add")
type(multiply_add_p)

jax._src.core.Primitive

In [21]:
for attr in dir(multiply_add_p):
    if not attr.startswith("__"):
        print(attr)

abstract_eval
bind
bind_with_trace
call_primitive
def_abstract_eval
def_custom_bind
def_effectful_abstract_eval
def_impl
get_bind_params
impl
map_primitive
multiple_results
name


In [22]:
def multiply_add_prim(x, y, z):
    """Bind the three arguments, positionally, to the `multiply_add` primitive."""
    bound_result = multiply_add_p.bind(x, y, z)
    return bound_result


def square_add_prim(x, y):
    """Square-add via the custom primitive: `x` is used as both factors."""
    squared_plus_y = multiply_add_prim(x, x, y)
    return squared_plus_y

In [23]:
try:
    square_add_prim(2.0, 10.0)
except Exception as e:
    print(e)

Evaluation rule for 'multiply_add' not implemented


## Define the Evaluation Rule and Register It with the Primitive

In [24]:
import numpy as np


def multiply_add_impl(x, y, z):
    """Evaluation rule: compute `x * y + z` with plain numpy."""
    product = np.multiply(x, y)
    return np.add(product, z)


# register Eval rule to primitive
multiply_add_p.def_impl(multiply_add_impl)

<function __main__.multiply_add_impl(x, y, z)>

In [25]:
assert square_add_prim(2.0, 10.0) == 14

In [26]:
try:
    jax.jit(square_add_prim)(2.0, 10.0)
except Exception as e:
    print(e)

Abstract evaluation for 'multiply_add' not implemented


## Define the Abstract Evaluation Rule and Register It with the Primitive

In [27]:
def multiply_add_abstract_eval(xs, ys, zs):
    """Abstract evaluation rule: check shapes and describe the output.

    All three abstract arguments must share one shape; the result is a
    ShapedArray with the shape and dtype of `xs`.
    """
    expected_shape = xs.shape
    assert expected_shape == ys.shape
    assert expected_shape == zs.shape
    return jax.core.ShapedArray(expected_shape, xs.dtype)


multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

<function __main__.multiply_add_abstract_eval(xs, ys, zs)>

In [28]:
try:
    jax.jit(square_add_prim)(2.0, 10.0)
except Exception as e:
    print(e)

MLIR translation rule for primitive 'multiply_add' not found for platform cpu


## XLA Compilation Rule

Lowering pipeline: jaxpr → MLIR (HLO dialect). The rule registered below emits the HLO ops for `multiply_add`.

In [29]:
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import mlir

def multiply_add_lowering(ctx, xc, yc, zc):
    """Lowering rule: emit HLO multiply then add, returning `[xc * yc + zc]`.

    `ctx` is accepted but not used.
    """
    product_op = hlo.MulOp(xc, yc)
    sum_op = hlo.AddOp(product_op, zc)
    return [sum_op.result]



mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform="cpu")

<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>

In [30]:
jax.jit(square_add_prim)(2.0, 10.0)

Array(14., dtype=float32)

In [31]:
try:
    jax.grad(square_add_prim)(2.0, 10.0)
except Exception as e:
    print(e)

Differentiation rule for 'multiply_add' not implemented


In [32]:
from jax._src.interpreters import ad

dir(ad)

['Any',
 'Callable',
 'CustomJVPException',
 'CustomVJPException',
 'JVPTrace',
 'JVPTracer',
 'Literal',
 'Partial',
 'Primitive',
 'Sequence',
 'Trace',
 'Tracer',
 'UndefinedPrimal',
 'Zero',
 '__annotations__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 '_closed_call_transpose',
 '_custom_lin_transpose',
 '_interleave',
 '_jvp_jaxpr',
 '_perm',
 '_primal_tangent_shapes_match',
 '_update_annotation',
 'add_jaxvals',
 'add_jaxvals_p',
 'add_tangents',
 'annotations',
 'as_hashable_function',
 'backward_pass',
 'bilinear_transpose',
 'call_p',
 'call_param_updaters',
 'call_transpose',
 'call_transpose_param_updaters',
 'closed_backward_pass',
 'config',
 'contextlib',
 'core',
 'custom_lin_p',
 'defbilinear',
 'defjvp',
 'defjvp2',
 'defjvp_zero',
 'deflinear',
 'deflinear2',
 'dtype',
 'f_jvp_traceable',
 'flatten_fun',
 'flatten_fun_nokwargs',
 'float0',
 'functools',
 'get_aval',
 'get_primitive_transpose',
 '

## Forward Differentiation

In [33]:
from jax.interpreters import ad

def mul_add_value_and_jvp(arg_value, arg_tangent):
    """Forward-mode rule: return the primal output and its tangent.

    This function must be JAX-traceable. The tangent is
    `xt * y + x * yt + zt`, expressed as two nested uses of the primitive.
    """
    x, y, z = arg_value
    xt, yt, zt = arg_tangent

    def make_zero(tan):
        # An ad.Zero tangent is materialised as zeros shaped like `x`.
        if type(tan) is ad.Zero:
            return jax.lax.zeros_like_array(x)
        return tan

    primal_out = multiply_add_prim(x, y, z)

    xt_dense = make_zero(xt)
    yt_dense = make_zero(yt)
    zt_dense = make_zero(zt)
    x_yt_plus_zt = multiply_add_prim(x, yt_dense, zt_dense)
    tangent_out = multiply_add_prim(xt_dense, y, x_yt_plus_zt)
    return (primal_out, tangent_out)

ad.primitive_jvps[multiply_add_p] = mul_add_value_and_jvp

In [34]:
from jax._src import api
# Use the public `jax.jvp` entry point rather than the private `jax._src.api` module.
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert jax.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)

In [35]:
# JIT-compile the forward-mode derivative through the public `jax.jit` / `jax.jvp`
# entry points rather than the private `jax._src.api` module.
assert jax.jit(lambda arg_values, arg_tangents:
                   jax.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)

In [36]:
try:
    jax.grad(square_add_prim)(2.,10.)
except Exception:
    traceback.print_exc(limit=5)

Traceback (most recent call last):
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 287, in get_primitive_transpose
    return primitive_transposes[p]
           ~~~~~~~~~~~~~~~~~~~~^^^
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/root/miniconda3/envs/nlp/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode diff

## Reverse Differentiation (Transpose Rule)

In [37]:
def multiply_add_transpose(ct, x, y, z):
    """Transpose rule used for reverse-mode differentiation.

    Exactly one of `x` / `y` is expected to be an undefined primal (the
    linear input); the other is a known value. Returns a triple of
    cotangents for `(x, y, z)`: None for the known factor, the cotangent
    for the linear factor, and `ct` passed through for `z`.
    """
    ct_is_zero = type(ct) is ad.Zero

    if not ad.is_undefined_primal(x):
        # `x` is the known factor, so `y` must be the linear input.
        assert ad.is_undefined_primal(y)
        if ct_is_zero:
            ct_y = ad.Zero(y.aval)
        else:
            ct_y = multiply_add_prim(x, ct, jax.lax.zeros_like_array(x))
        return None, ct_y, ct

    # Otherwise `x` is the linear input and `y` is used as the known factor.
    assert ad.is_undefined_primal(x)
    if ct_is_zero:
        ct_x = ad.Zero(x.aval)
    else:
        ct_x = multiply_add_prim(ct, y, jax.lax.zeros_like_array(y))
    return ct_x, None, ct

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose

In [38]:
jax.grad(square_add_prim)(2.0,10.0)

Array(4., dtype=float32)

## Batching

In [42]:
try:
    jax.vmap(square_add_prim,in_axes=0,out_axes=0)(np.array([2.,3.]),np.array([10.,20.]))
except Exception as e:
    print(e)

Batching rule for 'multiply_add' not implemented


In [43]:
from jax.interpreters import  batching

def mul_add_batch(vector_arg_values, batch_axes):
    """Batching rule: apply the primitive directly to the batched operands.

    All three operands must be batched along the same axis; that axis is
    returned as the batch axis of the result.
    """
    out_axis = batch_axes[0]
    assert out_axis == batch_axes[1]
    assert out_axis == batch_axes[2]
    batched_out = multiply_add_prim(*vector_arg_values)
    return batched_out, out_axis

batching.primitive_batchers[multiply_add_p] = mul_add_batch

In [44]:
jax.vmap(square_add_prim,in_axes=0,out_axes=0)(np.array([2.,3.]),np.array([10.,20.]))

array([14., 29.])