# Mixed Precision in JAX

## JMP
https://github.com/deepmind/jmp

In [1]:
import jax
import jax.numpy as jnp
import jmp

In [2]:
half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32

In [3]:
my_policy = jmp.Policy(
    compute_dtype=half,
    param_dtype=full,
    output_dtype=half
)

# alternative syntax
# my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")

In [4]:
# def layer(params, x):
#   params, x = my_policy.cast_to_compute((params, x))
#   w, b = params
#   y = x @ w + b
#   return my_policy.cast_to_output(y)

# params = {"w": jnp.ones([], dtype=my_policy.param_dtype)}
# y = layer(params, x)
# assert y.dtype == half

In [1]:
from functools import partial
from jax import random
key = random.PRNGKey(42)
key, subkey = jax.random.split(key)

N = 10000
x = random.normal(key, shape=(N,N))
w = random.normal(subkey, shape=(N,N))

def dot(x,w, my_policy):
    x,w = my_policy.cast_to_compute((x,w))
    y = x @ w
    return my_policy.cast_to_output(y)

from jax import jit 
dot_compiled = jit(partial(dot,my_policy=my_policy))

NameError: name 'jax' is not defined

In [None]:
%timeit y = dot(x,w,my_policy)

26 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [31]:
%timeit y = dot_compiled(x,w)

26.1 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [32]:
float16 = jmp.get_policy("float16")  # Everything in f16.
half = jmp.get_policy("half")        # Everything in half (f16 or bf16).
float32 = jmp.get_policy("float32")

In [33]:
dot_compiled_half = jit(partial(dot,my_policy=float16))

In [36]:
%timeit y = dot_compiled_half(x,w)

26.2 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [37]:
dot_compiled_full = jit(partial(dot,my_policy=float32))

In [38]:
%time y = dot_compiled_full(x,w)

CPU times: user 24.9 ms, sys: 0 ns, total: 24.9 ms
Wall time: 21.6 ms


In [39]:
x_half, w_half = half.cast_to_param((x,w))
%time dot_compiled_half(x_half,w_half)

CPU times: user 22.3 ms, sys: 315 µs, total: 22.6 ms
Wall time: 19.4 ms


DeviceArray([[ -43.06 ,   16.53 ,  148.1  , ...,  -47.53 ,  -90.5  ,
                15.734],
             [-139.6  , -150.5  ,  192.6  , ...,  101.06 ,   80.94 ,
                37.6  ],
             [-201.9  ,   19.1  ,  141.5  , ...,   32.97 , -119.44 ,
               224.5  ],
             ...,
             [-122.2  ,  -51.03 ,   90.75 , ...,  102.5  ,  -60.66 ,
               -12.29 ],
             [ -48.88 ,   65.1  ,   13.36 , ...,   18.5  ,   43.   ,
              -136.4  ],
             [ -29.25 ,  122.56 ,   26.1  , ..., -136.8  ,   82.56 ,
                19.28 ]], dtype=float16)

In [40]:
%time y = dot_compiled_full(x_half,w_half)

CPU times: user 59.9 ms, sys: 4.49 ms, total: 64.4 ms
Wall time: 53.7 ms


In [41]:
@jit
def dot(x,w, my_policy):
    x,w = my_policy.cast_to_compute((x,w))
    y = x @ w
    return my_policy.cast_to_output(y)

In [44]:
dot_ = jit(partial(dot,my_policy=my_policy))

In [45]:
%timeit dot_(x,w)

TypeError: Argument 'Policy(param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.float16'>, output_dtype=<class 'jax.numpy.float16'>)' of type <class 'jmp._src.policy.Policy'> is not a valid JAX type.