In [1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [2]:
# Build a PRNG key from seed 0 and use it to draw a length-10 vector of
# normally distributed samples.
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
# Time a (size x size) float32 matrix product computed with jax.numpy.
# block_until_ready() is called on the result so the measurement waits for the
# computation to finish rather than just for the call to return.
# NOTE(review): `key` is the same key already used for the previous draw.
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T).block_until_ready()  # runs on the GPU

10.9 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Same product, but the operand is created with plain NumPy and handed to
# jax.numpy's dot.
# NOTE(review): presumably slower than the previous cell because the host
# array has to be moved to the device on each call — confirm.
import numpy as onp  # original CPU-backed NumPy
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()

26.8 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
# Create the operand with NumPy again, but explicitly place it on the device
# with device_put() once, before the timed matrix product.
from jax import device_put

x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()

11 ms ± 100 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
# Baseline: the same float32 matrix product computed entirely with plain
# (CPU-backed) NumPy.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit onp.dot(x, x.T)

211 ms ± 2.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Apply the scaled exponential linear unit elementwise.

    Args:
        x: Input array.
        alpha: Scale of the exponential branch, used where ``x > 0`` is false.
        lmbda: Multiplier applied to the selected branch.

    Returns:
        ``lmbda * x`` where ``x > 0`` and ``lmbda * (alpha * exp(x) - alpha)``
        everywhere else.
    """
    negative_branch = alpha * np.exp(x) - alpha
    activated = np.where(x > 0, x, negative_branch)
    return lmbda * activated

In [8]:
# One million normal samples used as the input for the selu timings below
# (drawn with the same `key` as the earlier cells).
x = random.normal(key, (1_000_000,))

In [9]:
# Time selu as a plain Python function (no jit wrapper).
%timeit selu(x).block_until_ready()

1.46 ms ± 422 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
# Wrap selu with jit and time the wrapped version on the same input.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

110 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [11]:
# Ratio of un-jitted to jitted selu time: (milliseconds * 1000) / microseconds.
# NOTE(review): 110 matches the jitted timing recorded above, but 2.05 ms does
# not match the 1.46 ms recorded for the un-jitted run — the numerator looks
# stale from an earlier execution; confirm and re-run.
2.05 * 1000 / 110

18.636363636363637

In [12]:
def sum_logistic(x):
    """Return the sum over ``x`` of the logistic function ``1 / (1 + exp(-x))``.

    Args:
        x: Scalar or array input.

    Returns:
        A scalar: the logistic function applied elementwise, then summed.
    """
    denominator = 1.0 + np.exp(-x)
    return np.sum(1.0 / denominator)

# Build the gradient function of sum_logistic with grad() and evaluate it at
# the three-element input produced by np.arange(3.).
x_small = np.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [13]:
def first_finite_difference(f, x, eps=1e-3):
    """Estimate the gradient of ``f`` at ``x`` using central differences.

    Args:
        f: Function mapping an array shaped like ``x`` to a scalar.
        x: Point at which to estimate the gradient; must support ``len``.
        eps: Step size taken along each coordinate direction.

    Returns:
        An array with one entry per coordinate of ``x``, each equal to
        ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)``.
    """
    estimates = []
    # Each row of the identity matrix is one unit coordinate direction.
    for direction in np.eye(len(x)):
        forward = f(x + eps * direction)
        backward = f(x - eps * direction)
        estimates.append((forward - backward) / (2 * eps))
    return np.array(estimates)

# Numerical estimate of the same gradient, for comparison with the grad()
# result printed earlier.
print(first_finite_difference(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]


In [14]:
# Third derivative of sum_logistic evaluated at 1.0: grad is applied three
# times, with a jit wrapper placed between successive grad applications.
print(
    grad(
        jit(
            grad(
                jit(
                    grad(
                        sum_logistic
                    )
                )
            )
        )
    )(1.0)
)

-0.03532559


In [15]:
from jax import jacfwd, jacrev
def hessian(fun):
    """Build a jit-wrapped function that computes the Hessian of ``fun``.

    Args:
        fun: Function to differentiate twice.

    Returns:
        ``jit(jacfwd(jacrev(fun)))``: jacrev is applied to ``fun`` first,
        jacfwd is applied to that result, and the whole thing is wrapped in jit.
    """
    reverse_jacobian = jacrev(fun)
    forward_over_reverse = jacfwd(reverse_jacobian)
    return jit(forward_over_reverse)

In [16]:
# Evaluate the Hessian of sum_logistic at x_small.
hessian(sum_logistic)(x_small)

DeviceArray([[-0.        , -0.        , -0.        ],
             [-0.        , -0.09085777, -0.        ],
             [-0.        , -0.        , -0.07996248]], dtype=float32)

In [17]:
# Display x_small, the point at which the Hessian was evaluated.
x_small

array([0., 1., 2.])

In [18]:
# A (150, 100) matrix and a batch of 10 length-100 vectors for the batching
# examples below.
# NOTE(review): both draws use the same `key`, so the two arrays are not
# independent samples — consider random.split(key) if independence matters.
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

In [19]:
def apply_matrix(v):
    """Multiply the module-level matrix ``mat`` by ``v``.

    Args:
        v: Right-hand operand passed to ``np.dot(mat, v)``.

    Returns:
        The result of ``np.dot(mat, v)``.
    """
    product = np.dot(mat, v)
    return product

In [20]:
def naively_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` to each item of ``v_batched`` in a Python loop.

    Args:
        v_batched: Iterable of inputs, each passed to ``apply_matrix``.

    Returns:
        The per-item results combined with ``np.stack``.
    """
    results = []
    for v in v_batched:
        results.append(apply_matrix(v))
    return np.stack(results)

In [21]:
# Time the Python-loop version on the whole batch.
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
3.37 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
# Peek at the first 3 rows and first 5 columns of the batch.
batched_x[:3, :5]

DeviceArray([[-1.5721827 , -0.5877473 , -0.15419528, -0.68907833,
               0.659211  ],
             [ 0.23441708,  1.2395371 , -1.0322325 , -0.97337836,
               1.4354849 ],
             [-0.39501408,  1.1313331 ,  0.77410907, -2.1247523 ,
              -0.92607194]], dtype=float32)

In [23]:
# Display the result of the loop-based version for the whole batch.
naively_batched_apply_matrix(batched_x).block_until_ready()

DeviceArray([[  7.6967783 ,  -1.2004204 ,  -7.8025274 , ...,
               -9.198534  ,  -4.6361337 ,   5.0350065 ],
             [ -3.7661114 ,  22.081488  ,  15.895488  , ...,
              -10.457462  ,  11.743891  ,  -1.0312204 ],
             [ -6.269892  ,   0.04147506, -10.759269  , ...,
                6.7663145 ,   0.24011374,  -4.183442  ],
             ...,
             [  1.8337102 ,  10.578623  ,  17.425434  , ...,
                6.059695  ,   0.08220291,  -3.881513  ],
             [-15.693324  ,  18.875978  ,  10.051355  , ...,
               -1.5857544 , -14.738106  ,  13.019888  ],
             [  3.3215523 ,  -0.08373046,  15.172007  , ...,
                7.6211233 ,  -6.221728  , -15.446037  ]], dtype=float32)

In [24]:
@jit
def batched_apply_matrix(v_batched):
    """Apply the module-level matrix ``mat`` to a whole batch in one product.

    Args:
        v_batched: Batch operand passed as the left argument of ``np.dot``.

    Returns:
        ``np.dot(v_batched, mat.T)``.
    """
    mat_transposed = mat.T
    return np.dot(v_batched, mat_transposed)

# Time the hand-written batched version on the same batch.
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
110 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [25]:
@jit
def vmap_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` across a batch by transforming it with ``vmap``.

    Args:
        v_batched: Batch of inputs passed to the vmapped ``apply_matrix``.

    Returns:
        The result of ``vmap(apply_matrix)(v_batched)``.
    """
    apply_to_batch = vmap(apply_matrix)
    return apply_to_batch(v_batched)

# Time the vmap-based version on the same batch.
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
107 µs ± 2.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
