In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
# Build a PRNG key from seed 0; random.normal below takes the key explicitly
# as its first argument.
key = random.PRNGKey(0)
# Draw a length-10 vector of normal samples with that key.
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
# Side length of the square matrix used in the matrix-product timings.
size = 3000
# NOTE(review): this passes the same `key` that was already used to draw the
# length-10 vector, rather than a freshly split key.
x = random.normal(key, (size, size), dtype=jnp.float32)
# block_until_ready() is called on the result so that the timing waits for the
# computed value -- presumably because JAX dispatches work asynchronously.
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

8.12 ms ± 572 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
import numpy as np
# Here x is a NumPy array (created with np.random, cast to float32) instead of
# an array produced by jax.random.
x = np.random.normal(size=(size, size)).astype(np.float32)
# Same jnp.dot product as before, but fed a NumPy input.
# NOTE(review): the timing presumably includes moving x to the device on every
# call -- confirm against the JAX documentation.
%timeit jnp.dot(x, x.T).block_until_ready()

23.5 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
from jax import device_put

# Start again from a NumPy float32 matrix ...
x = np.random.normal(size=(size, size)).astype(np.float32)
# ... but pass it through device_put once, before the timed loop, so the timed
# expression receives the device_put result instead of the raw NumPy array.
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

8.67 ms ± 41.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Computes ``lmbda * x`` where ``x > 0`` and
  ``lmbda * alpha * (exp(x) - 1)`` everywhere else.

  Args:
    x: Input scalar or array.
    alpha: Scale of the exponential (non-positive) branch.
    lmbda: Overall output scale.

  Returns:
    The activated values, elementwise over ``x``.
  """
  # jnp.where evaluates (and differentiates) both branches. If the exponential
  # were fed large positive inputs it would overflow to inf; the forward pass
  # discards that value, but the gradient would become 0 * inf = NaN. Zeroing
  # the positive entries first keeps the unselected branch finite.
  negative_part = jnp.where(x > 0, 0.0, x)
  # expm1 is also more accurate than exp(x) - 1 for inputs close to zero.
  return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(negative_part))

# One million normal samples as input for timing the selu defined above.
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

359 µs ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# Wrap selu with jit and time the wrapped function on the same input x.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

75.3 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
@jit
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise (jit-wrapped).

  Computes ``lmbda * x`` where ``x > 0`` and
  ``lmbda * alpha * (exp(x) - 1)`` everywhere else.

  Args:
    x: Input scalar or array.
    alpha: Scale of the exponential (non-positive) branch.
    lmbda: Overall output scale.

  Returns:
    The activated values, elementwise over ``x``.
  """
  # jnp.where evaluates (and differentiates) both branches. If the exponential
  # were fed large positive inputs it would overflow to inf; the forward pass
  # discards that value, but the gradient would become 0 * inf = NaN. Zeroing
  # the positive entries first keeps the unselected branch finite.
  negative_part = jnp.where(x > 0, 0.0, x)
  # expm1 is also more accurate than exp(x) - 1 for inputs close to zero.
  return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(negative_part))

# One million normal samples as input for timing the @jit-decorated selu.
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

76 µs ± 1.84 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
def sum_logistic(x):
  """Sum of the logistic function ``1 / (1 + exp(-x))`` over all entries of x."""
  denominator = 1.0 + jnp.exp(-x)
  logistic = 1.0 / denominator
  return jnp.sum(logistic)

# Small test input: [0., 1., 2.].
x_small = jnp.arange(3.)
# grad(sum_logistic) is a new function that evaluates the gradient of the
# scalar-valued sum_logistic with respect to its argument.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [10]:
# grad and jit can be nested: three grad applications give the third
# derivative of sum_logistic, evaluated here at the scalar 1.0.
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256


In [11]:
from jax import jacfwd, jacrev
def hessian(fun):
  """Build a jit-wrapped function that evaluates the Hessian of ``fun``.

  The Hessian is obtained by applying forward-mode differentiation (jacfwd)
  to the reverse-mode Jacobian (jacrev) of ``fun``.
  """
  forward_over_reverse = jacfwd(jacrev(fun))
  return jit(forward_over_reverse)

In [12]:
# Fixed (150, 100) matrix and a batch of 10 length-100 vectors used by the
# batching examples.
# NOTE(review): both draws use the same `key`, so they are not independent
# random streams; JAX's guidance is to derive separate keys with random.split.
# Left unchanged here because the printed outputs depend on these values.
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  """Multiply the module-level matrix ``mat`` by a single vector ``v``."""
  product = jnp.dot(mat, v)
  return product

In [17]:
def naively_batched_apply_matrix(v_batched):
  """Apply ``apply_matrix`` to each row of ``v_batched`` in a Python loop.

  The per-row results are stacked along a new leading axis.
  """
  per_row_results = []
  for v in v_batched:
    per_row_results.append(apply_matrix(v))
  return jnp.stack(per_row_results)

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
# Bare final expression: the notebook displays the resulting array so it can
# be compared with the other batching variants.
naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.4 ms ± 11.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Array([[  7.6967783 ,  -1.2004204 ,  -7.8025274 , ...,  -9.198534  ,
         -4.636133  ,   5.0350065 ],
       [ -3.7661114 ,  22.081488  ,  15.895488  , ..., -10.457462  ,
         11.74389   ,  -1.0312204 ],
       [ -6.269892  ,   0.04147506, -10.759269  , ...,   6.7663145 ,
          0.24011374,  -4.183442  ],
       ...,
       [  1.8337102 ,  10.578623  ,  17.425434  , ...,   6.059695  ,
          0.08220291,  -3.881513  ],
       [-15.693324  ,  18.875978  ,  10.051355  , ...,  -1.5857544 ,
        -14.738107  ,  13.019888  ],
       [  3.3215528 ,  -0.0837307 ,  15.172005  , ...,   7.6211243 ,
         -6.221728  , -15.446035  ]], dtype=float32)

In [19]:
@jit
def batched_apply_matrix(v_batched):
  """Apply ``mat`` to every row of ``v_batched`` with a single matrix product.

  Multiplying the batch by the transpose of ``mat`` handles all rows at once
  instead of looping over them.
  """
  mat_transposed = mat.T
  return jnp.dot(v_batched, mat_transposed)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
# Bare final expression: the notebook displays the resulting array so it can
# be compared with the other batching variants.
batched_apply_matrix(batched_x).block_until_ready()

Manually batched
26.4 µs ± 312 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Array([[  7.6967764 ,  -1.2004201 ,  -7.8025265 , ...,  -9.198536  ,
         -4.636134  ,   5.0350065 ],
       [ -3.7661119 ,  22.081488  ,  15.895488  , ..., -10.457462  ,
         11.743892  ,  -1.0312209 ],
       [ -6.269892  ,   0.04147583, -10.75927   , ...,   6.7663155 ,
          0.24011254,  -4.183442  ],
       ...,
       [  1.8337111 ,  10.578621  ,  17.425432  , ...,   6.0596957 ,
          0.08220172,  -3.8815124 ],
       [-15.693325  ,  18.875978  ,  10.051354  , ...,  -1.5857532 ,
        -14.738108  ,  13.019889  ],
       [  3.3215542 ,  -0.08372951,  15.172006  , ...,   7.6211233 ,
         -6.221728  , -15.446035  ]], dtype=float32)

In [20]:
@jit
def vmap_batched_apply_matrix(v_batched):
  """Apply ``apply_matrix`` across the rows of ``v_batched`` using vmap."""
  apply_to_batch = vmap(apply_matrix)
  return apply_to_batch(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
# Bare final expression: the notebook displays the resulting array so it can
# be compared with the other batching variants.
vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
25.4 µs ± 728 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Array([[  7.6967764 ,  -1.2004201 ,  -7.8025265 , ...,  -9.198536  ,
         -4.636134  ,   5.0350065 ],
       [ -3.7661119 ,  22.081488  ,  15.895488  , ..., -10.457462  ,
         11.743892  ,  -1.0312209 ],
       [ -6.269892  ,   0.04147583, -10.75927   , ...,   6.7663155 ,
          0.24011254,  -4.183442  ],
       ...,
       [  1.8337111 ,  10.578621  ,  17.425432  , ...,   6.0596957 ,
          0.08220172,  -3.8815124 ],
       [-15.693325  ,  18.875978  ,  10.051354  , ...,  -1.5857532 ,
        -14.738108  ,  13.019889  ],
       [  3.3215542 ,  -0.08372951,  15.172006  , ...,   7.6211233 ,
         -6.221728  , -15.446035  ]], dtype=float32)

In [23]:
@jit
def f(x, y):
  """Return ``dot(x + 1, y + 1)``, printing the arguments and result on the way.

  The print calls are plain Python side effects. Under jit they run while the
  function is being traced, so they show tracer objects rather than concrete
  numbers.
  """
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  shifted_x = x + 1
  shifted_y = y + 1
  product = jnp.dot(shifted_x, shifted_y)
  print(f"  result = {product}")
  return product

# NOTE(review): no NumPy seed is set in the visible cells, so these inputs
# (and therefore the displayed result) differ from run to run.
x = np.random.randn(3, 4)
y = np.random.randn(4)
# First call of the jit-decorated f: the prints inside it fire and show traced
# values, then the computed array is displayed as the cell result.
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([10.145861 ,  5.298538 ,  5.1421375], dtype=float32)