# 5.3. JAX Internals and Primitives

> **Author**: Gustavo Leite / **Date**: March 2022.


In [1]:
# STANDARD LIBRARY IMPORTS
# ===============================================================================================

from typing import NamedTuple

# LOCAL IMPORTS
# ===============================================================================================

# Utility code used in this notebook
# Provides functions for tracing and logging
# Source: ./util/tracing.py
import util.tracing as tracing

# EXTERNAL IMPORTS
# ===============================================================================================

import numpy as np      # Standard Numpy
import jax.numpy as jnp # JAX Numpy namespace that mimics standard Numpy

from jax import (
    jit,         # Transformation for just-in-time compiling
    grad,        # Transformation for computing the gradient (i.e. derivative)
    jvp,         # Transformation for the Jacobian-vector product (forward-mode autodiff)
    vjp,         # Transformation for the vector-Jacobian product (reverse-mode autodiff)
    vmap,        # Transformation for vectorization
    pmap,        # Transformation for parallelization
    make_jaxpr,  # Transformation for dumping the corresponding JAXPR
    lax,         # Namespace where default primitives are defined
)

from jax.core import (
    Primitive,   # Object denoting a primitive operation
    ShapedArray, # Object denoting an abstract array with a shape and a type
)

from jax._src.lib import (
    xla_client,  # Namespace that allows us to interact with the XLA Python API
)

from jax.interpreters import (
    xla,         # Register primitives for compilation with XLA
    ad,          # Register primitives for auto-differentiation
    batching,    # Register primitives for batching
)

## PART 0: Introduction

JAX offers a series of composable function transformations like JIT compilation, auto-differentiation, auto-vectorization, and more. In order to carry out these transformations, JAX needs to know about the "structure" of the function being transformed. For instance, take the case of autodiff: JAX needs to know, somehow, which mathematical operations are applied to each input parameter in order to compute the derivative analytically. This suggests an abstract representation of the function, like a *computation graph*, that relates inputs and operations to the outputs of the function. The operations in this computation graph are called **Primitives** and they will be the main object of study in this notebook.

The next logical question is: how do we obtain such a computation graph? If JAX were part of the Python interpreter, it could simply get a reference to the Abstract Syntax Tree (AST) of the function and apply transformations based on it. However, JAX is distributed as a library, so it needs another mechanism for obtaining the computation graph: **tracing**. Tracing consists of calling the function to be transformed with special objects in place of the inputs. These special objects overload the mathematical operations in Python and are able to "record" everything that is done to them. This record of operations is precisely what we are looking for: a computation graph. In JAX terms, this computation graph is called a JAXPR, short for *JAX Expression*. The special objects used in place of the real parameters are called Tracers. We will talk more about them later.

The next cell shows an extremely simplified example of tracing a function by overloading binary operators. The class `AbstractVar` represents an abstract variable: it has a name but no value; the class `Expression` represents a recursive structure of computations performed on abstract variables; and the class `Tracer` builds that structure of expressions for us.

In [2]:
# In this cell we recreate the JAX tracing mechanism in an extremely simplified way.

class AbstractVar(NamedTuple):
    """Class that represents an abstract variable."""
    name: str
    

class Expression(NamedTuple):
    """Class that represents a binary operation."""
    op: str
    lhs: object
    rhs: object
        
        
class Tracer(object):
    """Tracer object that `records` operations."""
    value: object
        
    def __init__(self, value):
        self.value = value
        
    def __repr__(self):
        return repr(self.value)
        
    # Overload the '+' (addition) operation
    def __add__(self, rhs):
        return Tracer(Expression(op="add", lhs=self, rhs=rhs))
    
    # Overload the '*' (multiplication) operation
    def __mul__(self, rhs):
        return Tracer(Expression(op="mul", lhs=self, rhs=rhs))
    
    # Overload the '**' (power) operation
    def __pow__(self, rhs):
        return Tracer(Expression(op="pow", lhs=self, rhs=rhs))
    
    # Ideally, we would overload ALL Python operators; this is essentially what JAX does.
    # We overload only a few of them here to keep the example small.
    ...

    
def my_function(x, y):
    """This is the function we would like to trace.
    
    Note how it is pure: we compute the output from the inputs only."""
    return x ** 2. + x * y


# Let's test with real inputs
# ----------------------------------------------------------------------------

print("Invocation with real inputs:\n==> ", end="")
result = my_function(2., 10.)  # Call function with regular floats
print(result)

# Now test with Tracer inputs
# ----------------------------------------------------------------------------

print("\nInvocation with Tracer inputs:\n==> ", end="")
X = Tracer(AbstractVar("x"))   # Create a tracer to represent input `x`
Y = Tracer(AbstractVar("y"))   # Create a tracer to represent input `y`
result = my_function(X, Y)     # Call our function with tracers instead of numbers
print(result)  

Invocation with real inputs:
==> 24.0

Invocation with Tracer inputs:
==> Expression(op='add', lhs=Expression(op='pow', lhs=AbstractVar(name='x'), rhs=2.0), rhs=Expression(op='mul', lhs=AbstractVar(name='x'), rhs=AbstractVar(name='y')))


Take a moment to understand what this example is doing. In the first invocation of `my_function`, we passed floating-point numbers as parameters and got another floating-point number as output, just as we expect. In the second invocation, however, we passed `Tracer` objects and got back another tracer that encapsulates the recursive structure of operations. Isn't it interesting that we can discover the structure of a function without actually computing it?

The recursive structure returned in the second invocation can be understood as a directed acyclic graph (DAG): the node for `x` is shared by the `pow` and `mul` operations.

<center>
    <br />
    <img src="images/expression_tree.png" alt="Expression tree." width="60%" />
</center>

> **Note**: Curious readers who would like to know more about the tracing mechanism are referred to this [notebook](https://jax.readthedocs.io/en/latest/autodidax.html) in the JAX Docs.

One limitation of this tracing approach is that JAX can only trace the operations applied to its tracers, such as addition, multiplication, division, etc. It cannot, for instance, detect when a variable is printed to standard output or written to a file. In this sense, JAX expects functions to be *pure*: the function being transformed should produce its output based solely on its inputs and not depend on any external implicit state (*e.g.* global variables, files, etc.). Another way to put it is that functions must be *free of side effects*.
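To see what *free of side effects* means in practice, here is a minimal sketch, assuming any recent JAX release (`double` and `side_effects` are illustrative names, not part of the notebook). Under `jit`, a Python side effect runs only while JAX traces the function, not on later calls that reuse the compiled code:

```python
import jax

side_effects = []  # global list mutated by the function (this makes it impure)

@jax.jit
def double(x):
    # Python side effect: JAX cannot record it in the JAXPR,
    # so it only happens while the function is being traced.
    side_effects.append("traced")
    return x * 2.0

double(1.0)  # first call: traced, compiled and executed
double(2.0)  # same abstract input (float32 scalar): reuses the compiled code
print(side_effects)  # ['traced']
```

Because the append never makes it into the traced graph, it silently disappears from every call after the first one.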

Once the computation graph is available, JAX can traverse it and do whatever it needs to: compile operations when doing JIT, differentiate expressions when doing autodiff; you get the idea. Based on this scenario, we will explore how we can extend the default operations offered by JAX by creating new primitives that are compatible with the existing transformations.
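As a small illustration of traversing the graph, the sketch below (assuming a recent JAX release) asks JAX for the JAXPR of a function similar to `my_function`, with `x * x` in place of `x ** 2.`, and walks its equations in order, collecting the name of the primitive each one applies:

```python
import jax

def f(x, y):
    return x * x + x * y

closed_jaxpr = jax.make_jaxpr(f)(2.0, 10.0)
print(closed_jaxpr)

# Walk the equations of the graph in order and collect the primitive names.
prim_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(prim_names)  # ['mul', 'mul', 'add']
```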

In the following parts of this document, we will discuss how to make a new primitive work with JIT compilation, autodiff, and batching.

**Table of Contents**

1. JAX Primitives and JIT
2. JAX Primitives and Autodiff
3. JAX Primitives and Batching
4. Summary

<hr />

## PART 1: JAX Primitives and JIT


### Preparing the Ground

Say that we want to implement a function `square_add(a, b)` that returns $a^2 + b$ and we want to use the `multiply_add(x, y, z): return x * y + z` function to implement `square_add`. Using JAX Numpy, we could write these functions like this:

In [3]:
@tracing.log_calls
def multiply_add_numpy(x, y, z):
    """Multiply `x` by `y` and add `z`."""
    return jnp.add(jnp.multiply(x, y), z)
#          ~~~     ~~~
#           `-------`--> We are using JAX Numpy functions but
#                        they do exactly the same thing as the
#                        standard numpy functions `np.add` and
#                        `np.multiply`.

@tracing.log_calls
def square_add_numpy(a, b):
    """Square `a` and add `b`"""
    return multiply_add_numpy(a, a, b)

> **Note**: The decorator `@tracing.log_calls` logs the invocations of the decorated function to stdout. Nested invocations are indented for readability; you will see what that means in the next cell.

Let us now call our new `square_add_numpy` function:

In [4]:
square_add_numpy(2., 10.)

| CALL square_add_numpy(2.0, 10.0)
|   | CALL multiply_add_numpy(2.0, 2.0, 10.0)
|   | RET  multiply_add_numpy = 14.0
| RET  square_add_numpy = 14.0


DeviceArray(14., dtype=float32, weak_type=True)

Immediately, some things come to our attention:

1. The nested calling and returning from functions is dumped to standard output;
2. As expected, `square_add_numpy` receives two floats as input (2 and 10) and returns another float (14);
3. The result of this computation is an object of type `DeviceArray`.

Apart from the formatted call stack, this is exactly what we expect from evaluating this function. Now let us try the `grad` transformation.

In [5]:
grad(square_add_numpy)(2., 10.)

| CALL square_add_numpy(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | CALL multiply_add_numpy(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | RET  multiply_add_numpy = Tracer<ConcreteArray(14.0, dtype=float32, weak_type=True)>
| RET  square_add_numpy = Tracer<ConcreteArray(14.0, dtype=float32, weak_type=True)>


DeviceArray(4., dtype=float32, weak_type=True)

New observations can be made:

1. There is our call stack again;
2. The first parameter (the one we differentiate with respect to) was wrapped in a `Tracer<ConcreteArray(...)>`, while the second one was passed as a plain float (notice something familiar?);
3. The returned value, 4.0, is exactly the derivative of `square_add(a, b)` with respect to `a`, evaluated at $a = 2$.
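Since `grad` differentiates with respect to the first positional argument by default, the result above is only $\partial f / \partial a$. The `argnums` parameter of `jax.grad` selects other arguments. A minimal sketch, using a local copy of the function (`square_add` is redefined here without the logging decorator):

```python
import jax
import jax.numpy as jnp

def square_add(a, b):
    return jnp.add(jnp.multiply(a, a), b)

# Differentiate with respect to both `a` (index 0) and `b` (index 1).
da, db = jax.grad(square_add, argnums=(0, 1))(2.0, 10.0)
print(da, db)  # 4.0 1.0
```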

Remember the example from the introduction? Well, the `Tracer` and `ConcreteArray` objects are part of the machinery that allows tracing to occur inside JAX. The `Tracer` class serves the same purpose as in our initial example, while `ConcreteArray` is called an *abstract value* and serves as a generalization of the `AbstractVar` class from the introduction. There are many types of abstract values in JAX, and each type corresponds to how much we know about the value of a variable:

- `ConcreteArray`s mean that we know the type, the shape and the CONTENTS of the variable.
- `ShapedArray`s mean that we know only the type and shape, but not the contents.

Similarly, there are several kinds of `Tracer`s in JAX, each specialized for a particular transformation. The class diagram below gathers a non-exhaustive list of the tracers and abstract values found in the JAX source code. Classes highlighted in orange work as abstract base classes.

<center>
    <br />
    <img src="images/avals_and_tracers.png" alt="Abstract values and tracers in JAX" width="50%" />
</center>
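The public function `jax.eval_shape` gives a feel for the `ShapedArray` level of knowledge: it traces a function using inputs that carry only a shape and a dtype, and returns the shape and dtype of the output without computing any values. A minimal sketch, where `column_sums` is just an illustrative function:

```python
import jax
import jax.numpy as jnp

def column_sums(a):
    return a.sum(axis=0)

# Describe the input by shape and dtype only: no actual data is allocated.
spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)
out = jax.eval_shape(column_sums, spec)
print(out.shape, out.dtype)  # (3,) float32
```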

<hr />

Up until this point we have only used built-in JAX features to write our functions. Now we are going to recreate the `multiply_add` operation from the ground up, *i.e.* without using `jnp.*` functions. For that, we need to create a new `Primitive`!

### Creating a new Primitive

A primitive is simply an object with a name which represents an operation in JAX.
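Built-in operations are primitives too. The sketch below assumes a JAX version whose `jax.lax` namespace exports the `add_p` and `mul_p` primitive objects (current releases do); it checks that the equations of a traced function refer to those very objects:

```python
import jax
from jax import lax

print(lax.add_p.name, lax.mul_p.name)  # add mul

# Trace a small function and collect the primitive objects its equations apply.
closed = jax.make_jaxpr(lambda x, y: x * y + x)(1.0, 2.0)
used = {eqn.primitive for eqn in closed.jaxpr.eqns}
print(used == {lax.mul_p, lax.add_p})  # True
```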

In [6]:
# Instantiate our new primitive
multiply_add_p = Primitive("multiply_add")

@tracing.log_calls
def multiply_add_prim(x, y, z):
    """The JAX-traceable way to use the JAX primitive.
    
    This is the function that will be exposed to the user."""
    return multiply_add_p.bind(x, y, z)
#                         ~~~~
#                           `-> This `bind` is the important part!

@tracing.log_calls
def square_add_prim(a, b):
    """The JAX-traceable way to use the JAX primitive.
    
    This is another function that will be exposed to the user."""
    return multiply_add_p.bind(a, a, b)

In the cell above, we `bind` the inputs to the `Primitive`. Binding hands the inputs of `multiply_add_prim` over to JAX's tracing machinery, which, depending on the transformation being applied, may wrap them in `Tracer`s carrying abstract values. For reasons that will become clear in a moment, binding a primitive is the entry point for all transformations, and the type of the tracer and abstract value determines which kind of information is extracted while tracing this primitive.

With our primitive instantiated, let's try to call our new function.

In [7]:
with tracing.ExpectNotImplemented():
    square_add_prim(2., 10.)

| CALL square_add_prim(2.0, 10.0)

Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_6900/367851666.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/scratch/research/notebooks/gustavo.leite/util/tracing.py", line 53, in fn_wrapper
    res = fn(*args)
  File "/tmp/ipykernel_6900/126395883.py", line 18, in square_add_prim
    return multiply_add_p.bind(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented


We got an error! Pay attention to the error message:

> **NotImplementedError**: Evaluation rule for 'multiply_add' not implemented

This error was expected. Even though we created a new primitive and gave it a name, we never told JAX what this primitive actually computes! We need to define an *implementation* for this primitive, or, as JAX puts it, an *evaluation rule*.

In [8]:
@tracing.log_calls
def multiply_add_impl(x, y, z):
    """Concrete implementation of the primitive.

    This function does NOT need to be JAX traceable.
    """
    # Note that we can use the original numpy, which is not JAX traceable
    return np.add(np.multiply(x, y), z)
#             ~~~    ~~~~~~~~
#              `--------`--> Same computation as before, but using standard Numpy.


# Register the implementation of `multiply_add_p` primitive.
multiply_add_p.def_impl(multiply_add_impl)

<function __main__.multiply_add_impl(x, y, z)>

Now we can call our new function, which binds our primitive! Please be mindful that we are not doing JIT yet: the computation is simply delegated to standard Numpy inside the `multiply_add_impl` function.

In [9]:
assert square_add_prim(2., 10.) == 14.

| CALL square_add_prim(2.0, 10.0)
|   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   | RET  multiply_add_impl = 14.0
| RET  square_add_prim = 14.0


Be mindful that this evaluation rule is executed on the CPU no matter what backend you are using. In order to run code on other devices we need to generate code for them. That will come later! 😉

Let's review our intuition at this point. Our functions are called in this order (arrows go from caller to callee).

<center>
    <br />
    <img src="images/primitive_magic.png" alt="Primitive magic" width="80%" />
</center>

What magic 🪄✨ is going on here? The magic is called `Trace`! (Not to be confused with `Tracer` with an **r**; this is confusing, I know 🤷.) A Trace is an object in the JAX core that processes primitives. There are many types of Traces, and each of them roughly corresponds to a specific transformation.

- `EvalTrace` is responsible for simply calling the primitive implementation;
- `DynamicJaxprTrace` is responsible for doing abstract evaluation (more on that later);
- `JVPTrace` is responsible for forward-mode differentiation (Jacobian-vector products), which `grad` also relies on;
- `BatchTrace` is responsible for vectorization and parallelization;
- etc.

Because we are not doing any transformations (yet), JAX defaults to the `EvalTrace` that simply calls the primitive implementation for us. Our intuition is now updated.

<center>
    <br />
    <img src="images/primitive_eval.png" alt="Primitive magic" width="80%" />
</center>

In fact, JAX keeps a stack of traces. When we bind a primitive, JAX takes the active trace (the one on top of the stack) and calls its `process_primitive` method, passing the primitive being bound. The following diagram gives a high-level overview of this process. While processing a primitive, other primitives may be encountered, and the process continues recursively until the entire function has been traced.

<center>
    <br />
    <img src="images/primitive_full.png" alt="Primitive magic" width="80%" />
</center>
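To make this dispatch concrete, here is a toy re-creation in plain Python. It is a deliberate simplification, not JAX's code: `ToyPrimitive`, `ToyEvalTrace`, `ToyRecordingTrace` and `TRACE_STACK` are hypothetical names, and real JAX has additional logic for selecting the trace and handling nested transformations.

```python
class ToyPrimitive:
    """Hypothetical stand-in for a JAX Primitive: a name plus an implementation."""

    def __init__(self, name, impl):
        self.name = name
        self.impl = impl

    def bind(self, *args):
        # Hand the primitive over to whichever trace is on top of the stack.
        return TRACE_STACK[-1].process_primitive(self, args)


class ToyEvalTrace:
    """Plays the role of EvalTrace: simply run the implementation."""

    def process_primitive(self, prim, args):
        return prim.impl(*args)


class ToyRecordingTrace:
    """Plays the role of a tracing Trace: record the operation instead of running it."""

    def __init__(self):
        self.log = []

    def process_primitive(self, prim, args):
        self.log.append((prim.name, args))
        return f"%{len(self.log) - 1}"  # symbolic name for the (unknown) result


TRACE_STACK = [ToyEvalTrace()]  # the bottom of the stack evaluates eagerly
multiply_add = ToyPrimitive("multiply_add", lambda x, y, z: x * y + z)

print(multiply_add.bind(2.0, 2.0, 10.0))  # 14.0

recorder = ToyRecordingTrace()
TRACE_STACK.append(recorder)               # "enter" a transformation
result = multiply_add.bind("x", "x", "y")  # recorded, not computed
TRACE_STACK.pop()                          # "leave" the transformation
print(result, recorder.log)  # %0 [('multiply_add', ('x', 'x', 'y'))]
```

The primitive never decides what happens to its inputs; the trace on top of the stack does, which is why the same `bind` call can either compute a number or record an operation.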

Having understood that, we can update the class diagram of the JAX core classes. The `Trace`s are the ones responsible for instantiating `Tracer`s of the correct type.


<center>
    <br />
    <img src="images/jax_core_simple.png" alt="Primitive magic" width="100%" />
</center>

<hr />

### Just-in-Time Compiling our Primitive

So far so good, we created a new primitive and gave it an evaluation rule. What if we decide to accelerate our code by JIT compiling our function? Let's try.

In [10]:
with tracing.ExpectNotImplemented():
    jit(square_add_prim)(2., 10.)

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)

Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_6900/2077843376.py", line 2, in <module>
    jit(square_add_prim)(2., 10.)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/api.py", line 430, in cache_miss
    out_flat = xla.xla_call(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented


We get another error: JAX complains that we have not defined an *abstract evaluation* rule. This rule computes the shape and data type of the output given the shapes and data types of the inputs. Remember that JAX compiles your function for specific input shapes: if you call the function with inputs of another shape, it will trigger a recompilation.
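A minimal sketch of this shape specialization, assuming a recent JAX release (`square_add` and `trace_count` are illustrative names): a Python counter incremented inside the jitted function only changes when JAX traces the function again, which happens when the input shapes change.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def square_add(a, b):
    global trace_count
    trace_count += 1  # Python side effect: counts traces, not calls
    return a * a + b

square_add(jnp.ones(3), jnp.ones(3))   # traced and compiled for shape (3,)
square_add(jnp.zeros(3), jnp.ones(3))  # same shapes and dtypes: cache hit
square_add(jnp.ones(5), jnp.ones(5))   # new shape (5,): traced and compiled again
print(trace_count)  # 2
```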

We define the **abstract evaluation** rule in the next cell.

In [11]:
@tracing.log_calls
def multiply_add_abstract_eval(xs, ys, zs):
    """Abstractly evaluate our primitive based on the input shapes."""
    # Assert that all input parameters have the same shape
    assert xs.shape == ys.shape
    assert xs.shape == zs.shape
    # Inform that the output has the same shape as the inputs
    return ShapedArray(xs.shape, xs.dtype)


# Register the abstract implementation of `multiply_add_p` primitive.
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
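The rule above only compares shapes, and it relies on `assert`, which Python skips when run with `-O`. A stricter variant might also compare dtypes and raise a regular exception. This is a hypothetical sketch (`multiply_add_abstract_eval_strict` is not part of the notebook); it assumes `ShapedArray` is importable from `jax.core`, as in the imports at the top, and could be passed to `multiply_add_p.def_abstract_eval` in place of the original rule:

```python
import jax.numpy as jnp
from jax.core import ShapedArray

def multiply_add_abstract_eval_strict(xs, ys, zs):
    """Hypothetical stricter abstract evaluation rule for `multiply_add`."""
    if not (xs.shape == ys.shape == zs.shape):
        raise TypeError(f"multiply_add: shapes differ: {xs.shape}, {ys.shape}, {zs.shape}")
    if not (xs.dtype == ys.dtype == zs.dtype):
        raise TypeError(f"multiply_add: dtypes differ: {xs.dtype}, {ys.dtype}, {zs.dtype}")
    # The output has the same shape and dtype as the inputs.
    return ShapedArray(xs.shape, xs.dtype)

f32 = ShapedArray((4,), jnp.float32)
i32 = ShapedArray((4,), jnp.int32)

out = multiply_add_abstract_eval_strict(f32, f32, f32)
print(out.shape, out.dtype)  # (4,) float32

try:
    multiply_add_abstract_eval_strict(f32, i32, f32)
    mismatch_error = None
except TypeError as err:
    mismatch_error = err
print(mismatch_error)  # multiply_add: dtypes differ: float32, int32, float32
```

Since the rule only ever sees shapes and dtypes, it can be exercised directly like this, without tracing or compiling anything.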

With that out of the way, let's try calling our primitive one more time.

In [12]:
with tracing.ExpectNotImplemented():
    jit(square_add_prim)(2., 10.)

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
|   | CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|   | RET  multiply_add_abstract_eval = ShapedArray(float32[])
| RET  square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>

Found expected exception:


Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 191, in _run_module_as_main
    msg = "%s: %s" % (sys.executable, exc)
  File "/usr/lib/python3.10/runpy.py", line 75, in _run_code
    fname = mod_spec.origin
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/ipykernel_launcher.py", line 12, in <module>
    if sys.path[0] == '':
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform gpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_6900/2077843376.py", line 2, in <module>
    jit(square_add_prim)(2., 10.)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 16

Another error! Now JAX complains because it doesn't know how to translate our primitive into an intermediate representation (IR) that XLA understands, so that the function can finally be compiled into a native executable. This is aptly called a *translation rule*. We use the XLA client API to create *mul* and *add* instructions. Let's check it out:

In [13]:
# Alias for convenience
xops = xla_client.ops

@tracing.log_calls
def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):
    """Translate our function to XLA's intermediate-representation."""
    return [xops.Add(xops.Mul(xc, yc), zc)]
#           ~~~~~~~~ ~~~~~~~~
#               `--------`--> We are creating the IR for our function here!


# Associate the new translation rule with our primitive
# We use the same translation function for both CPU and GPU, but we need to inform them separately.
xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='cpu')
xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='gpu')

Finally, JIT should be working now.

In [14]:
assert jit(square_add_prim)(2., 10.) == 14.

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>)
|   | CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|   | RET  multiply_add_abstract_eval = ShapedArray(float32[])
| RET  square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
| CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
| RET  multiply_add_abstract_eval = ShapedArray(float32[])
| CALL multiply_add_xla_translation(TranslationContext(builder=<jaxlib.xla_extension.XlaBuilder object at 0x7ff6e416b370>, platform='gpu', axis_env=AxisEnv(nreps=1, names=(), sizes=()), name_stack=''), [ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), Shaped

As you can see, invoking our primitive causes it to be first evaluated abstractly, then translated using the rule we defined, and finally executed on the device. We can also tell JAX that some parameters of the function are static using the `static_argnums` argument. With `static_argnums=1`, we are saying that the parameter with index 1 of `square_add_prim` is static and should not be traced. Compare the first line of the trace above with the first line of the trace below: in the former, `square_add_prim` is invoked with two tracers, while in the latter the second parameter is the literal value `10.0`.

In [15]:
assert jit(square_add_prim, static_argnums=1)(2., 10.) == 14.

| CALL square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, 10.0)
|   | CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|   | RET  multiply_add_abstract_eval = ShapedArray(float32[])
| RET  square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
| CALL multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
| RET  multiply_add_abstract_eval = ShapedArray(float32[])
| CALL multiply_add_xla_translation(TranslationContext(builder=<jaxlib.xla_extension.XlaBuilder object at 0x7ff6e4173630>, platform='gpu', axis_env=AxisEnv(nreps=1, names=(), sizes=()), name_stack=''), [ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], [ShapedArray(float32[])], <XlaOp at 0x7ff
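Because static arguments are not traced, JAX treats each new value of a static argument as a new compilation. A minimal sketch, assuming a recent JAX release (`scale` and `seen_factors` are illustrative names); note that static arguments must be hashable, which Python floats are:

```python
import jax
import jax.numpy as jnp

seen_factors = []

def scale(x, factor):
    seen_factors.append(factor)  # runs once per trace, i.e. once per new static value
    if factor == 0:              # Python control flow needs a concrete value; a static argument provides one
        return jnp.zeros_like(x)
    return x * factor

scale_jit = jax.jit(scale, static_argnums=1)

scale_jit(jnp.ones(3), 2.0)  # traced and compiled for factor=2.0
scale_jit(jnp.ones(3), 2.0)  # same static value: cache hit
scale_jit(jnp.ones(3), 3.0)  # new static value: traced and compiled again
print(seen_factors)  # [2.0, 3.0]
```

Without `static_argnums=1`, `factor` would be a tracer and the `if factor == 0` test would fail during tracing.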

We are done, at least with the JIT transformation! Let us compare the two implementations: (a) the one that uses regular `jnp.*` functions; and (b) the one that implements a new primitive.

Here are the JAXPR and the MLIR (in the MHLO dialect) of (a).

In [16]:
with tracing.SuppressCallLog():
    print(make_jaxpr(square_add_numpy)(2., 10.))
    print("=" * 80)
    print(jit(square_add_numpy).lower(2., 10.).compiler_ir('mhlo'))

{ lambda ; a:f32[] b:f32[]. let c:f32[] = mul a a; d:f32[] = add c b in (d,) }
module @jit_square_add_numpy.11 {
  func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<f32>
    %1 = mhlo.add %0, %arg1 : tensor<f32>
    return %1 : tensor<f32>
  }
}



And here are the JAXPR and the MLIR of (b). Note how, instead of `mul` and `add` operations, we have a single `multiply_add` in the JAXPR.

In [17]:
with tracing.SuppressCallLog():
    print(make_jaxpr(square_add_prim)(2., 10.))
    print("=" * 80)
    print(jit(square_add_prim).lower(2., 10.).compiler_ir('mhlo'))

{ lambda ; a:f32[] b:f32[]. let c:f32[] = multiply_add a a b in (c,) }
module @jit_square_add_prim.12 {
  func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = call @multiply_add(%arg0, %arg0, %arg1) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
  func private @multiply_add(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
    %0 = call @xla_fallback_multiply_add(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
  func private @xla_fallback_multiply_add(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
    %0 = mhlo.constant dense<false> : tensor<i1>
    %1 = mhlo.multiply %arg0, %arg1 : tensor<f32>
    %2 = mhlo.add %1, %arg2 : tensor<f32>
    return %2 : tensor<f32>
  }
}
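The `compiler_ir('mhlo')` call above reflects the JAX version used when writing this notebook. In more recent releases, a lowered computation can also be printed with `Lowered.as_text()`, which in those versions emits StableHLO rather than MHLO. A sketch assuming such a release, applied to a plain `jnp` version of the function (`square_add` is redefined locally):

```python
import jax
import jax.numpy as jnp

def square_add(a, b):
    return jnp.add(jnp.multiply(a, a), b)

lowered = jax.jit(square_add).lower(2.0, 10.0)  # trace and lower, without compiling
ir_text = lowered.as_text()                      # MLIR module as text
print(ir_text)
```

The printed module should contain a multiply followed by an add, mirroring the JAXPR of (a).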



### What have we done so far...

The following table summarizes what we needed to do in order to implement a new primitive and make it work with the `jit` transformation.

| Step | Description | API |
|:----:|:------------|:----|
| 1 | Create new primitive object | `Primitive(NAME)` |
| 2 | Define concrete evaluation rule | `PRIMITIVE.def_impl(IMPL_FN)` |
| 3 | Define abstract evaluation rule | `PRIMITIVE.def_abstract_eval(ABS_EVAL_FN)` |
| 4 | Define translation rule | `xla.register_translation(PRIMITIVE, TRANSLATION_FN)` |

In the next section we will enable our primitive to be used with `grad`.

<hr />

## PART 2: JAX Primitives and Autodiff

> **Note**: This section is incomplete. You can stop reading here for now.

In [18]:
with tracing.ExpectNotImplemented():
    jvp(square_add_prim, (2., 10.), (1., 1.))

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(10.0, dtype=float32, weak_type=True)>)

Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_6900/635092508.py", line 2, in <module>
    jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/api.py", line 2280, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/api.py", line 2309, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented


We have this function:

$$ f(x, y, z) = x \cdot y + z $$

The first partial derivatives are:

$$
    \nabla f = (y, x, 1)
$$

And, writing the tangent vector as $\vec{t} = (x_t, y_t, z_t)$, we compute its JVP as:

$$
    \vec{t} \cdot \nabla f = \nabla_{\vec{t}} f = x_t\, y + y_t\, x + z_t
$$
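Before writing the rule, the formula can be sanity-checked numerically. This sketch uses `jax.jvp` on a plain Python version of $f$ built from ordinary operations (not on our custom primitive), at the same point used in this section: $x = y = 2$, $z = 10$, with all tangents equal to 1:

```python
import jax

def f(x, y, z):
    return x * y + z

primals = (2.0, 2.0, 10.0)  # (x, y, z); x and y both play the role of `a` in square_add
tangents = (1.0, 1.0, 1.0)  # (x_t, y_t, z_t)

primal_out, tangent_out = jax.jvp(f, primals, tangents)
print(primal_out, tangent_out)  # 14.0 5.0

# Compare against the formula x_t * y + y_t * x + z_t.
x, y, z = primals
xt, yt, zt = tangents
assert tangent_out == xt * y + yt * x + zt
```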

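In code, we can express this JVP rule by reusing `multiply_add_prim` itself: the output tangent $x_t\, y + (x\, y_t + z_t)$ is just two nested `multiply_add` operations.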

In [19]:
@tracing.log_calls
def multiply_add_value_and_jvp(arg_values, arg_tangents):
    """Evaluates the primal output and the tangents (Jacobian-vector product).

    Given values of the arguments and perturbation of the arguments (tangents), 
    compute the output of the primitive and the perturbation of the output.

    This method must be JAX-traceable. JAX may invoke it with abstract values 
    for the arguments and tangents.
    """
    x,  y,  z  = arg_values
    xt, yt, zt = arg_tangents
    
    tracing.log(">>> Primal evaluation:")
    primal_out = multiply_add_prim(x, y, z)
    
    def make_zero(tan):
        return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
    
    tracing.log(">>> Tangent evaluation:")
    output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
    
    return primal_out, output_tangent

# Register the JVP rule for our `multiply_add_p` primitive
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp

In [20]:
assert jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
|   | CALL multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
|   |   >>> Primal evaluation:
|   |   | CALL multiply_add_prim(2.0, 2.0, 10.0)
|   |   |   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   |   |   | RET  multiply_add_impl = 14.0
|   |   | RET  multiply_add_prim = 14.0
|   |   >>> Tangent evaluation:
|   |   | CALL multiply_add_prim(2.0, 1.0, 1.0)
|   |   |   | CALL multiply_add_impl(2.0, 1.0, 1.0)
|   |   |   | RET  multiply_add_impl = 3.0
|   |   | RET  multiply_add_prim = 3.0
|   |   | CALL multiply_add_prim(1.0, 2.0, 3.0)
|   |   |   | CALL multiply_add_impl(1.0, 2.0, 3.0)
|   |   |   | RET  multiply_add_impl = 5.0
|   |   | RET  multiply_add_prim = 5.0
|   | RET  multiply_add_value_and_jvp = (14.0, 5.0)
| RET  square_add_prim = Tracer<ConcreteArray(14.0, dtype=float32)>


In [21]:
with tracing.ExpectNotImplemented():
  grad(square_add_prim)(2., 10.)

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | CALL multiply_add_value_and_jvp((Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0), (Tracer<ShapedArray(float32[], weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
|   |   >>> Primal evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   |   |   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   |   |   | RET  multiply_add_impl = 14.0
|   |   | RET  multiply_add_prim = 14.0
|   |   >>> Tangent evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, 0.0)
|   |   |   | CALL multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), Shape

Traceback (most recent call last):
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/interpreters/ad.py", line 258, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 191, in _run_module_as_main
    msg = "%s: %s" % (sys.executable, exc)
  File "/usr/lib/python3.10/runpy.py", line 75, in _run_code
    fname = mod_spec.origin
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/ipykernel_launcher.py", line 12, in <module>
    if sys.path[0] == '':
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The ab

In [22]:
@tracing.log_calls
def multiply_add_transpose(ct, x, y, z):
    """Evaluates the transpose of a linear primitive.

    This method is only used when computing the backward gradient following
    value_and_jvp, and is only needed for primitives that are used in the JVP
    calculation for some other primitive. We need transposition for multiply_add_prim
    because we have used multiply_add_prim in the computation of the output_tangent in
    multiply_add_value_and_jvp.

    In our case, multiply_add is not a linear primitive. However, it is used linearly
    w.r.t. tangents in multiply_add_value_and_jvp:
       output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))

    In each of these uses, one of the first two (multiplicative) arguments is a constant.

    Args:
      ct: the cotangent of the output of the primitive.
      x, y, z: the arguments of the primitive. Arguments used linearly arrive as
        ad.UndefinedPrimal values; the others arrive as constants.
    Returns:
      a tuple with the cotangents of x, y and z, with None for the constant argument.
    """
    if not ad.is_undefined_primal(x):
        # This use of multiply_add is with a constant "x"
        assert ad.is_undefined_primal(y)
        ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
        return None, ct_y, ct
    else:
        # This use of multiply_add is with a constant "y"
        assert ad.is_undefined_primal(x)
        ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
        return ct_x, None, ct
        
# Register transpose rule for `multiply_add_p` primitive.
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
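The transpose rule can be sanity-checked without JAX. A linear map $L$ and its transpose $L^\top$ must satisfy $\langle ct, L(v) \rangle = \langle L^\top(ct), v \rangle$ for every $v$ and $ct$. The following is a NumPy-only sketch that mirrors the constant-`x` branch of `multiply_add_transpose`: with `x` fixed, `(yt, zt) -> x * yt + zt` is linear, and its transpose maps `ct` to `(x * ct, ct)`. The helpers `linear_use` and `transpose_use` are hypothetical names introduced only for this check.

```python
import numpy as np


def linear_use(x, yt, zt):
    """Hypothetical: multiply_add(x, yt, zt) with x held constant, a linear map of (yt, zt)."""
    return x * yt + zt


def transpose_use(x, ct):
    """Hypothetical: cotangents for (yt, zt), mirroring the constant-"x" branch of
    multiply_add_transpose (ct_y = multiply_add(x, ct, 0), ct_z = ct)."""
    ct_y = x * ct + np.zeros_like(x)
    ct_z = ct
    return ct_y, ct_z


rng = np.random.default_rng(0)
x = rng.normal(size=4)    # constant (primal) argument
yt = rng.normal(size=4)   # tangent inputs
zt = rng.normal(size=4)
ct = rng.normal(size=4)   # cotangent of the output

# Adjoint identity: <ct, L(yt, zt)> == <ct_y, yt> + <ct_z, zt>
lhs = np.dot(ct, linear_use(x, yt, zt))
ct_y, ct_z = transpose_use(x, ct)
rhs = np.dot(ct_y, yt) + np.dot(ct_z, zt)
print(np.isclose(lhs, rhs))  # True
```

The constant-`y` branch can be checked the same way by swapping the roles of `x` and `y`.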

In [23]:
assert grad(square_add_prim)(2., 10.) == 4.

| CALL square_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   | CALL multiply_add_value_and_jvp((Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0), (Tracer<ShapedArray(float32[], weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
|   |   >>> Primal evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|   |   |   | CALL multiply_add_impl(2.0, 2.0, 10.0)
|   |   |   | RET  multiply_add_impl = 14.0
|   |   | RET  multiply_add_prim = 14.0
|   |   >>> Tangent evaluation:
|   |   | CALL multiply_add_prim(Tracer<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Tracer<ShapedArray(float32[], weak_type=True)>, 0.0)
|   |   |   | CALL multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), Shape
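As a worked check of the assertion `grad(square_add_prim)(2., 10.) == 4.`, the gradient can be traced by hand. The derivation assumes, as the trace output shows, that `square_add_prim(a, b)` calls `multiply_add_prim(a, a, b)`, and that the JVP rule builds the output tangent as `multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))`. The outer use has a constant `y` and the inner use has a constant `x`, so each branch of `multiply_add_transpose` contributes one term to $\bar{a}$:

$$
\begin{aligned}
s &= \mathrm{multiply\_add}(a, a, b) = a \cdot a + b \\
\dot{s} &= \underbrace{\dot{a} \cdot a}_{\text{constant } y} + \underbrace{a \cdot \dot{a}}_{\text{constant } x} + \dot{b} \\
\bar{a} &= \bar{s} \cdot a + a \cdot \bar{s} = 2a\,\bar{s}, \qquad \bar{b} = \bar{s}
\end{aligned}
$$

With $\bar{s} = 1$ (the seed that `grad` uses for a scalar output) and $a = 2$, this gives $\bar{a} = 4$.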

### RECAP

Consider a composition of functions $f(g(h(\dots(z(x)))))$. We normally evaluate such a function from the inside out: first we calculate $z(x)$, then use the result as the input to the next function outward, and so on until we reach $f$. Derivatives can be accumulated in either direction: the chain rule lets us compute them from the inside out or from the outside in. These are known as forward-mode and backward-mode (reverse-mode) autodiff, respectively.

<br>

<center>
    <img src="images/fwd_bwd_ad.png" alt="autodiff" width="30%" />
</center>

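Concretely, the chain rule writes the derivative of $f(g(h(x)))$ as a product of Jacobians $J_f \, J_g \, J_h$, each evaluated at the corresponding intermediate value. Forward mode pushes a tangent vector $v$ through this product starting from the innermost function, computing $J_f \, (J_g \, (J_h \, v))$, whereas backward mode pulls a cotangent vector $u$ back starting from the outermost function, computing $((u^\top J_f) \, J_g) \, J_h$.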
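The two accumulation orders can be illustrated with plain NumPy matrices standing in for Jacobians. This is only a sketch under simplifying assumptions: the three functions are taken to be linear, so each Jacobian is a constant matrix, and the names `J_h`, `J_g`, `J_f`, `v` and `u` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Jacobians of three linear functions h: R^3 -> R^4, g: R^4 -> R^5, f: R^5 -> R^2.
# For linear maps the Jacobian is the same at every point, which keeps the sketch small.
J_h = rng.normal(size=(4, 3))
J_g = rng.normal(size=(5, 4))
J_f = rng.normal(size=(2, 5))

v = rng.normal(size=3)  # input tangent (forward mode)
u = rng.normal(size=2)  # output cotangent (backward mode)

# Forward mode: push v through the chain starting from the innermost function h.
jvp_out = J_f @ (J_g @ (J_h @ v))   # shape (2,)

# Backward mode: pull u back through the chain starting from the outermost function f.
vjp_out = ((u @ J_f) @ J_g) @ J_h   # shape (3,)

# Both are products with the same full Jacobian J = J_f J_g J_h.
J = J_f @ J_g @ J_h
print(np.allclose(jvp_out, J @ v), np.allclose(vjp_out, u @ J))  # True True
print(np.isclose(u @ jvp_out, vjp_out @ v))                      # True
```

Forward mode yields one column-direction product $Jv$ per pass and backward mode one row-direction product $u^\top J$ per pass, which is why backward mode is the cheaper choice for scalar-valued functions such as losses.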

<hr />

## PART 3: JAX Primitives and Batching

In [24]:
a = np.array([2., 3.])
b = np.array([10., 20.])

# The arguments are two vectors instead of two scalars
with tracing.ExpectNotImplemented():
    vmap(square_add_prim, in_axes=0, out_axes=0)(a, b)

| CALL square_add_prim(Tracer<ShapedArray(float32[])>, Tracer<ShapedArray(float32[])>)

Found expected exception:


Traceback (most recent call last):
  File "/tmp/ipykernel_6900/44637082.py", line 6, in <module>
    vmap(square_add_prim, in_axes=0, out_axes=0)(a, b)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/scratch/research/notebooks/venv/lib/python3.10/site-packages/jax/_src/api.py", line 1555, in batched_fun
    out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented


In [25]:
@tracing.log_calls
def multiply_add_batch(vector_arg_values, batch_axes):
    """Computes the batched version of the primitive.

    This must be a JAX-traceable function.

    Since the multiply_add primitive already operates pointwise on arbitrary
    dimension tensors, to batch it we can use the primitive itself. This works as
    long as all three inputs have the same dimensions and are batched along the
    same axis. The result is batched along that same axis.

    Args:
      vector_arg_values: a tuple of three arguments (x, y, z), each being a tensor
        of matching shape.
      batch_axes: the axis along which each argument is batched. See the vmap
        documentation.
    Returns:
      a tuple of the result, and the axis along which the result is batched.
    """
    assert batch_axes[0] == batch_axes[1]
    assert batch_axes[0] == batch_axes[2]
    tracing.log(">>> Using multiply_add to compute the batch")
    res = multiply_add_prim(*vector_arg_values)
    return res, batch_axes[0]


# Register batching rule for `multiply_add_p` primitive.
batching.primitive_batchers[multiply_add_p] = multiply_add_batch
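The `multiply_add_batch` rule asserts that all three inputs share the same batch axis. We would therefore expect a call such as `vmap(square_add_prim, in_axes=(0, None))` to fail that assertion, since `b` would arrive unbatched (JAX marks this case with `batching.not_mapped`, i.e. `None`). The sketch below emulates in plain NumPy, not with JAX primitives, what a more general rule would have to do: move every batch axis to the front and broadcast unbatched arguments over the batch. `multiply_add_batch_general` is a hypothetical name; a real batching rule must be written with JAX-traceable operations, as the docstring above notes.

```python
import numpy as np


def multiply_add_batch_general(args, batch_axes):
    """Hypothetical NumPy emulation of a more general batching rule for x * y + z.

    Accepts a different batch axis per argument, and unbatched arguments
    (batch axis None). Every batched argument is moved so that its batch axis
    is 0; unbatched arguments gain a leading axis of size 1 so that they
    broadcast over the batch.
    """
    moved = []
    for arg, axis in zip(args, batch_axes):
        arg = np.asarray(arg)
        if axis is None:
            moved.append(arg[np.newaxis])            # broadcast over the batch
        else:
            moved.append(np.moveaxis(arg, axis, 0))  # batch dimension first
    x, y, z = moved
    return x * y + z, 0  # the result is batched along axis 0


# square_add(a, b) = multiply_add(a, a, b), with `a` batched and `b` shared by all examples
a = np.array([2.0, 3.0])
b = 10.0
out, out_axis = multiply_add_batch_general((a, a, b), (0, 0, None))
print(out, out_axis)  # [14. 19.] 0
```

Always returning the result batched along axis 0 keeps the rule simple; `vmap` then moves that axis to wherever `out_axes` asks for it.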

In [26]:
vmap(square_add_prim, in_axes=0, out_axes=0)(a, b)

| CALL square_add_prim(Tracer<ShapedArray(float32[])>, Tracer<ShapedArray(float32[])>)
|   | CALL multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
|   |   >>> Using multiply_add to compute the batch
|   |   | CALL multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
|   |   |   | CALL multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
|   |   |   | RET  multiply_add_impl = [14. 29.]
|   |   | RET  multiply_add_prim = [14. 29.]
|   | RET  multiply_add_batch = ([14. 29.], 0)
| RET  square_add_prim = Tracer<ShapedArray(float32[])>


array([14., 29.])

### RECAP

<center>
    <img src="images/batching.png" alt="batching" width="50%" />
</center>

<hr />

## PART 4: Summary

1. Create a primitive:

```python
my_primitive_p = Primitive("my_primitive")
```

2. Define a concrete evaluation rule for interpretation:

```python
my_primitive_p.def_impl(my_primitive_impl)
```

3. Define an abstract evaluation rule (output shape and dtype), used whenever JAX traces with abstract values, e.g. for JIT compilation:

```python
my_primitive_p.def_abstract_eval(my_primitive_abs_eval)
```

4. Define an XLA translation rule for JIT compilation:

```python
jax.interpreters.xla.register_translation(
    my_primitive_p,
    my_primitive_xla_compile_gpu,
    platform='gpu')
```

5. Define a Jacobian-vector product (JVP) rule for forward-mode AD:

```python
jax.interpreters.ad.primitive_jvps[my_primitive_p] = my_primitive_value_and_jvp
```

6. Define a transpose rule for backward-mode AD (needed by `grad` and `vjp`):

```python
jax.interpreters.ad.primitive_transposes[my_primitive_p] = my_primitive_transpose
```

7. Define a batching rule (needed by `vmap`):

```python
jax.interpreters.batching.primitive_batchers[my_primitive_p] = my_primitive_batch
```
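Steps 5 to 7 all follow the same pattern: each transformation keeps a table keyed by primitive, and a missing entry surfaces as a `NotImplementedError`, as in the traceback for `grad` earlier in this notebook (`primitive_transposes[p]` raising `KeyError: multiply_add`). The following is a deliberately simplified toy model of that pattern in plain Python, not JAX's actual implementation; `ToyPrimitive` and `get_transpose_rule` are hypothetical names.

```python
class ToyPrimitive:
    """Hypothetical stand-in for a JAX Primitive: just a named, hashable key."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


# Toy rule tables, one per transformation (analogous to ad.primitive_jvps,
# ad.primitive_transposes and batching.primitive_batchers).
primitive_jvps = {}
primitive_transposes = {}
primitive_batchers = {}


def get_transpose_rule(prim):
    # A missing dictionary entry is reported as NotImplementedError,
    # mirroring the KeyError -> NotImplementedError chain in the traceback.
    try:
        return primitive_transposes[prim]
    except KeyError as err:
        raise NotImplementedError(
            f"Transpose rule (for reverse-mode differentiation) for '{prim.name}' not implemented"
        ) from err


toy_p = ToyPrimitive("toy_multiply_add")

try:
    get_transpose_rule(toy_p)
except NotImplementedError as e:
    print(e)

# Registering a rule is just a dictionary assignment, as in step 6.
# This toy rule handles only the constant-"x" case: (None, x * ct, ct).
primitive_transposes[toy_p] = lambda ct, x, y, z: (None, x * ct, ct)
print(get_transpose_rule(toy_p)(1.0, 2.0, None, None))  # (None, 2.0, 1.0)
```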

### JAX Core Reference

This is an attempt at a more or less complete picture of JAX's core classes.

<center>
    <br />
    <img src="images/jax_core_full.png" alt="" width="100%"/>
</center>