# Using existing `jax.lax` primitives

In [1]:
import jax
jax.config.update('jax_platform_name', 'cpu')

In [2]:
def mul_add_lax(x, y, z):
    """Compute ``x * y + z`` by composing the existing ``lax.mul`` and ``lax.add`` primitives."""
    product = jax.lax.mul(x, y)
    return jax.lax.add(product, z)


def square_add_lax(x, y):
    """Compute ``x * x + y`` by passing ``x`` as both factors of ``mul_add_lax``."""
    return mul_add_lax(x, y=x, z=y)


square_add_lax(2, 10)

Array(14, dtype=int32, weak_type=True)

In [3]:
# Differentiate w.r.t. the first argument: d/dx (x*x + y) = 2*x, i.e. 4.0 at x = 2.0.
jax.grad(square_add_lax)(2.0, 10.0)

Array(4., dtype=float32, weak_type=True)

# Defining a new primitive

## Create the new primitive

In [4]:
try:
    # Public home of `Primitive` in recent JAX releases (jax.core.Primitive is deprecated there).
    from jax.extend.core import Primitive as _Primitive
except ImportError:
    # Older JAX releases only expose it through jax.core.
    _Primitive = jax.core.Primitive

# A fresh primitive named "mul_add"; no evaluation, abstract-evaluation or
# lowering rules are attached yet, so binding it fails until they are registered.
mul_add_p = _Primitive("mul_add")
type(mul_add_p)

jax._src.core.Primitive

In [5]:
# Show the primitive's non-dunder API (bind, def_impl, def_abstract_eval, ...).
public_attrs = [attr_name for attr_name in dir(mul_add_p)
                if not attr_name.startswith("__")]
for attr_name in public_attrs:
    print(attr_name)

abstract_eval
bind
bind_with_trace
call_primitive
def_abstract_eval
def_custom_bind
def_effectful_abstract_eval
def_impl
get_bind_params
impl
map_primitive
multiple_results
name


In [6]:
def mul_add_prim(x, y, z):
    """Compute ``x * y + z`` by binding the custom ``mul_add_p`` primitive.

    Calling this raises until the rule needed for the current context
    (evaluation, abstract evaluation, lowering) is registered on ``mul_add_p``.
    """
    operands = (x, y, z)
    return mul_add_p.bind(*operands)


def square_add_prim(x, y):
    """Compute ``x * x + y`` via ``mul_add_prim``, using ``x`` as both factors."""
    return mul_add_prim(x, y=x, z=y)

In [7]:
# Expected to fail: mul_add_p has no evaluation (impl) rule yet.
# Catch only NotImplementedError so unrelated bugs still surface.
try:
    square_add_prim(2.0, 10.)
except NotImplementedError as e:
    print(e)
else:
    # Happens e.g. if a later cell (which registers the rule) already ran.
    print("unexpected: square_add_prim(2.0, 10.) succeeded; "
          "an evaluation rule is already registered")

Evaluation rule for 'mul_add' not implemented


## Define the evaluation rule and register it with the primitive

In [8]:
import numpy as np
def mul_add_impl(x, y, z):
    """Concrete evaluation rule for ``mul_add_p``: return ``x * y + z``.

    Written with NumPy rather than JAX ops. It is used when the primitive is
    called on concrete values outside of jit, e.g. ``square_add_prim(2., 10.)``.
    """
    product = np.multiply(x, y)
    return np.add(product, z)

# Register the evaluation rule used when mul_add_p is bound to concrete
# (non-jit) values. def_impl returns the function it was given, which is
# why the cell echoes mul_add_impl.
mul_add_p.def_impl(mul_add_impl)

<function __main__.mul_add_impl(x, y, z)>

In [9]:
# Eager call now succeeds via mul_add_impl: 2*2 + 10 == 14.
assert square_add_prim(2.,10.) == 14

In [10]:
# Expected to fail: eager evaluation works now, but jit tracing needs an
# abstract evaluation rule, which mul_add_p does not have yet.
# Catch only NotImplementedError so unrelated bugs still surface.
try:
    jax.jit(square_add_prim)(2.,10.)
except NotImplementedError as e:
    print(e)
else:
    # Happens e.g. if a later cell (which registers the rule) already ran.
    print("unexpected: jax.jit(square_add_prim) succeeded; "
          "an abstract evaluation rule is already registered")

Abstract evaluation for 'mul_add' not implemented


## Define the abstract evaluation rule and register it with the primitive

In [11]:
def mul_add_abstract_eval(xs,ys,zs):
    """Abstract evaluation (shape/dtype) rule for ``mul_add_p``.

    JAX calls this with abstract values instead of concrete arrays, e.g. while
    ``jax.jit`` traces ``square_add_prim``. Only ``.shape`` and ``.dtype`` of
    the operands are read.

    Args:
        xs, ys, zs: abstract values of the three operands of ``x * y + z``.

    Returns:
        jax.core.ShapedArray with the operands' common shape and dtype.

    Raises:
        TypeError: if ``ys`` or ``zs`` differs from ``xs`` in shape or dtype.
            The primitive supports neither broadcasting nor type promotion,
            because its lowering applies element-wise HLO multiply/add
            directly to the operands.
    """
    # Explicit raises rather than `assert`, so the checks survive `python -O`.
    # The dtype check keeps the declared output dtype (xs.dtype) truthful.
    for label, operand in (("ys", ys), ("zs", zs)):
        if operand.shape != xs.shape:
            raise TypeError(
                f"mul_add: {label} shape {operand.shape} does not match "
                f"xs shape {xs.shape}")
        if operand.dtype != xs.dtype:
            raise TypeError(
                f"mul_add: {label} dtype {operand.dtype} does not match "
                f"xs dtype {xs.dtype}")
    return jax.core.ShapedArray(xs.shape, xs.dtype)

# Register the abstract evaluation rule; jax.jit needs it to infer the
# primitive's output shape/dtype while tracing. Like def_impl, this returns
# the registered function, which the cell echoes.
mul_add_p.def_abstract_eval(mul_add_abstract_eval)

<function __main__.mul_add_abstract_eval(xs, ys, zs)>

In [12]:
# Expected to fail: tracing works now, but no MLIR lowering rule exists yet.
# Catch only NotImplementedError so unrelated bugs still surface.
try:
    jax.jit(square_add_prim)(2.0,10.)
except NotImplementedError as e:
    print(e)
else:
    # Happens e.g. if a later cell (which registers the lowering) already ran.
    print("unexpected: jax.jit(square_add_prim) succeeded; "
          "a lowering rule is already registered")

MLIR translation rule for primitive 'mul_add' not found for platform cpu


## XLA compilation (lowering) rule

Lowering translates each use of the primitive in the jaxpr into MLIR operations from the HLO dialect: Jaxpr → MLIR (HLO).

In [13]:
from jax._src.lib.mlir.dialects import hlo
def mul_add_lowering(ctx, xc, yc, zc):
    """Lower ``mul_add_p`` to MLIR as an HLO multiply followed by an HLO add.

    ``xc``, ``yc`` and ``zc`` are the MLIR values of the operands. ``ctx``
    (the lowering context) is not used. Returns a one-element list holding
    the result value.
    """
    product_op = hlo.MulOp(xc, yc)
    sum_op = hlo.AddOp(product_op, zc)
    return [sum_op.result]

from jax.interpreters import mlir

# The rule emits only generic element-wise HLO ops (MulOp/AddOp), so it is
# registered for every platform (no `platform=` restriction) instead of CPU only.
# register_lowering returns the rule, which the cell echoes.
mlir.register_lowering(mul_add_p, mul_add_lowering)

<function __main__.mul_add_lowering(ctx, xc, yc, zc)>

In [15]:
# With evaluation, abstract evaluation and lowering rules all registered, jit succeeds.
jax.jit(square_add_prim)(2.0,10.)

Array(14., dtype=float32)