In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
# Seed a PRNG key, then use it to sample ten standard-normal values.
key = random.PRNGKey(seed=0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


2024-01-10 16:30:22.170281: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# Benchmark a (size x size) float32 matrix product on a JAX-generated array.
# block_until_ready() is called so the timing covers the finished computation
# rather than only the dispatch of the operation.
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

965 µs ± 815 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
import numpy as np
# Same product as before, but `x` is now a NumPy array instead of a JAX array.
# NOTE(review): the slower timing recorded below presumably reflects the
# array being transferred to the device on every call — confirm.
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

9.32 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
# Explicitly place the NumPy matrix on the device with device_put() before
# timing, so the benchmarked product operates on the array that device_put()
# returned instead of on the raw NumPy array.
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

971 µs ± 25.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Useful Core JAX Features
- `jit()` for speeding up your code
- `grad()` for taking derivatives
- `vmap()` for automatic vectorization or batching

In [7]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied element-wise.

    Computes ``lmbda * x`` where ``x > 0`` and
    ``lmbda * alpha * (exp(x) - 1)`` elsewhere.

    Args:
        x: Scalar or array input.
        alpha: Scale of the negative (exponential) branch.
        lmbda: Overall output scale.

    Returns:
        An array with the same shape as ``x``.
    """
    positive = x > 0
    # Feed the exponential branch a harmless value (0) wherever the positive
    # branch is selected. Without this, a large positive x makes exp(x)
    # overflow to inf and reverse-mode differentiation yields 0 * inf = NaN,
    # even though that branch is discarded by the outer `where`.
    safe_x = jnp.where(positive, 0.0, x)
    # expm1 evaluates exp(x) - 1 without cancellation error near x == 0.
    return lmbda * jnp.where(positive, x, alpha * jnp.expm1(safe_x))

# Time the un-jitted selu on one million standard-normal samples.
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

249 µs ± 259 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
# Wrap selu with jit() (called as a function, not as a decorator) and time
# the wrapped version on the same input `x`.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

36.4 µs ± 62.6 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
def sum_logistic(x):
  """Return the sum of the logistic function 1 / (1 + exp(-x)) over `x`."""
  denominator = 1.0 + jnp.exp(-x)
  logistic = 1.0 / denominator
  return jnp.sum(logistic)

# Build the gradient function of sum_logistic and evaluate it at [0., 1., 2.].
derivative_fn = grad(sum_logistic)
x_small = jnp.arange(3.0)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [11]:
# grad and jit compose freely: differentiate sum_logistic three times,
# jit-wrapping between the grad calls, and evaluate the result at 1.0.
print(
    grad(jit(grad(jit(grad(sum_logistic)))))(1.0)
)

-0.0353256


In [12]:
# Draw `mat` and `batched_x` from two distinct subkeys. Passing the same key
# to both sampling calls would make the two arrays statistically dependent,
# which JAX's PRNG design warns against. `key` itself is left untouched so
# re-running this cell gives the same arrays.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
  """Return the product of the module-level matrix `mat` with `v`."""
  product = jnp.dot(mat, v)
  return product

In [13]:
def naively_batched_apply_matrix(v_batched):
  """Apply `apply_matrix` to each entry of `v_batched` in a Python loop and stack the results."""
  outputs = []
  for single_v in v_batched:
    outputs.append(apply_matrix(single_v))
  return jnp.stack(outputs)

# Label the output, then time the Python-loop version on `batched_x`.
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
992 µs ± 24 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
@jit
def batched_apply_matrix(v_batched):
  """Hand-written batched form: multiply `v_batched` by the transpose of `mat` in one dot call."""
  mat_transposed = jnp.transpose(mat)
  return jnp.dot(v_batched, mat_transposed)

# Label the output, then time the hand-batched, jitted version on `batched_x`.
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
45.8 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
@jit
def vmap_batched_apply_matrix(v_batched):
  """Batch `apply_matrix` with vmap and call the vectorized function on `v_batched`."""
  batched_fn = vmap(apply_matrix)
  return batched_fn(v_batched)

# Label the output, then time the vmap-based, jitted version on `batched_x`.
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
43.9 µs ± 176 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
