# Mixed Precision in JAX

## JMP
https://github.com/deepmind/jmp

In [1]:
import jax
import jax.numpy as jnp
import jmp

In [2]:
half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32

In [3]:
my_policy = jmp.Policy(
    compute_dtype=half,
    param_dtype=full,
    output_dtype=half
)

# alternative syntax
# my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")

In [4]:
# def layer(params, x):
#   params, x = my_policy.cast_to_compute((params, x))
#   w, b = params
#   y = x @ w + b
#   return my_policy.cast_to_output(y)

# params = {"w": jnp.ones([], dtype=my_policy.param_dtype)}
# y = layer(params, x)
# assert y.dtype == half

In [5]:
from functools import partial
from jax import random
key = random.PRNGKey(42)

N = 10000
x = random.normal(key, shape=(N,N))
w = random.normal(key, shape=(N,N))

def dot(x,w, my_policy):
    x,w = my_policy.cast_to_compute((x,w))
    y = x @ w
    return my_policy.cast_to_output(y)

from jax import jit 
dot_compiled = jit(partial(dot,my_policy=my_policy))

In [12]:
%timeit y = dot(x,w,my_policy)

223 µs ± 6.36 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
%timeit y = dot_compiled(x,w)

74.4 µs ± 545 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
float16 = jmp.get_policy("float16")  # Everything in f16.
half = jmp.get_policy("half")        # Everything in half (f16 or bf16).
float32 = jmp.get_policy("float32")

In [15]:
dot_compiled_half = jit(partial(dot,my_policy=float16))

In [16]:
%timeit y = dot_compiled_half(x,w)

73.2 µs ± 2.42 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
dot_compiled_full = jit(partial(dot,my_policy=float32))

In [21]:
%timeit y = dot_compiled_full(x,w)

205 µs ± 6.39 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [22]:
x_half, w_half = half.cast_to_param((x,w))
%timeit dot_compiled_half(x_half,w_half)

49.7 µs ± 764 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [23]:
%timeit y = dot_compiled_full(x_half,w_half)

233 µs ± 7.07 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
@jit
def dot(x,w, my_policy):
    x,w = my_policy.cast_to_compute((x,w))
    y = x @ w
    return my_policy.cast_to_output(y)

In [7]:
dot_ = partial(dot, my_policy=my_policy)

In [8]:
%timeit dot_(x,w)

TypeError: Argument 'Policy(param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.float16'>, output_dtype=<class 'jax.numpy.float16'>)' of type <class 'jmp._src.policy.Policy'> is not a valid JAX type.