+ jit(), for speeding up your code
+ grad(), for taking derivatives
+ vmap(), for automatic vectorization or batching
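The three transforms compose freely. As a minimal sketch (the `loss` function and the toy data are hypothetical, not from these notes), per-example gradients can be computed by taking the gradient of a single-example loss with `grad`, mapping it over a batch with `vmap`, and compiling the result with `jit`:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss: squared error of a linear model w . x.
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

# grad: derivative with respect to w (argument 0 by default).
# vmap: map over a batch of (x, y) pairs while sharing w (in_axes=None for w).
# jit: compile the whole composition into a single XLA computation.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.zeros(3)
print(per_example_grads(w, xs, ys))  # one gradient row per example
```

Each row is 2(w·x − y)x for one example, so the result has the same shape as `xs`.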

JAX vs. NumPy
+ JAX provides a NumPy-inspired interface for convenience.
+ Through duck-typing, JAX arrays can often be used as drop-in replacements for NumPy arrays.
+ Unlike NumPy arrays, JAX arrays are always immutable.
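A small sketch of what immutability means in practice: NumPy allows in-place item assignment, JAX raises a `TypeError`, and the functional `.at[...].set(...)` update returns a new array instead (the flag variable is only there for illustration):

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(5)
x_np[0] = 10  # fine: NumPy arrays are mutable

x = jnp.arange(5)
assignment_failed = False
try:
    x[0] = 10  # JAX arrays do not support in-place item assignment
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

# Functional update: .at[...].set(...) returns a new array and leaves x intact.
y = x.at[0].set(10)
print(x)
print(y)
```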

NumPy, lax & XLA: JAX API layering
+ jax.numpy is a high-level wrapper that provides a familiar interface.
+ jax.lax is a lower-level API that is stricter and often more powerful.
+ All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.

To JIT or not to JIT
+ By default JAX executes operations one at a time, in sequence.
+ Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
+ Not all JAX code can be JIT compiled: JIT compilation requires array shapes to be static & known at compile time.

JIT mechanics: tracing and static variables
+ JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.
+ Variables that you don’t want to be traced can be marked as static.
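A sketch of why a value sometimes has to be static: here the hypothetical `repeat_vector` uses `n` to determine the output shape, so it is marked with `static_argnames`; JAX then treats `n` as an ordinary Python value during tracing and compiles once per distinct `n`:

```python
from functools import partial
import jax
import jax.numpy as jnp

# `n` controls the output shape, so it cannot be a tracer.
# static_argnames keeps it as a plain Python int at trace time.
@partial(jax.jit, static_argnames=("n",))
def repeat_vector(x, n):
    return jnp.tile(x, n)

print(repeat_vector(jnp.array([1, 2]), n=3))
```

Static arguments must be hashable, and each new value triggers a fresh compilation, so this suits values with few distinct settings.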

Static vs Traced Operations
+ Just as values can be either static or traced, operations can be static or traced.
+ Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
+ Use numpy for operations that you want to be static; use jax.numpy for operations that you want to be traced.
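One way to see the static/traced split, as a sketch with a hypothetical function: inspect the jaxpr. The `np.prod` on the shape tuple runs in Python during tracing and leaves no trace, while `jnp.sum` is recorded as an XLA-level operation:

```python
import jax
import jax.numpy as jnp
import numpy as np

def flatten_and_normalize(x):
    n = np.prod(x.shape)  # static: x.shape is a concrete tuple, NumPy runs in Python
    total = jnp.sum(x)    # traced: recorded in the jaxpr and executed by XLA
    return x.reshape((n,)) / total

jaxpr = jax.make_jaxpr(flatten_and_normalize)(jnp.ones((2, 3)))
print(jaxpr)  # contains reduce_sum; the np.prod call does not appear
print(jax.jit(flatten_and_normalize)(jnp.ones((2, 3))))
```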

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import lax

import numpy as np

In [2]:
key = random.PRNGKey(0)
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
# JAX array: already on the device
%timeit -n10 -r3 jnp.dot(x, x.T).block_until_ready()

# NumPy array: transferred to the device on each call
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit -n10 -r3 jnp.dot(x, x.T).block_until_ready()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


232 ms ± 5.93 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
222 ms ± 13.9 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
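When a NumPy array is reused across many calls, a common pattern (sketched here; the variable names are illustrative) is to move it to the device once with `jax.device_put`, so the per-call transfer drops out of the measurement:

```python
import jax
import jax.numpy as jnp
import numpy as np

size = 3000
x_np = np.random.normal(size=(size, size)).astype(np.float32)

# Transfer once; subsequent calls operate on an array already on the default device.
x_dev = jax.device_put(x_np)
result = jnp.dot(x_dev, x_dev.T).block_until_ready()
print(type(x_dev), result.shape)
```

In the notebook this would be timed as `%timeit -n10 -r3 jnp.dot(x_dev, x_dev.T).block_until_ready()`. On a CPU-only machine, as in the run above, the difference may be small.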


In [3]:
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

7.28 ms ± 788 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
selu_jit = jit(selu)
selu_jit(x).block_until_ready()  # warm-up: the first call traces and compiles
%timeit selu_jit(x).block_until_ready()

1.54 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]
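As a sanity check on the result above, the derivative of the logistic function has the closed form σ(x)(1 − σ(x)), so `grad` can be compared against it directly. This sketch also uses `jax.value_and_grad`, which returns the function value and gradient together:

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)

# d/dx_i sum_j sigmoid(x_j) = sigmoid(x_i) * (1 - sigmoid(x_i))
s = jax.nn.sigmoid(x_small)
analytic = s * (1 - s)

value, autodiff = jax.value_and_grad(sum_logistic)(x_small)
print(value, autodiff, analytic)
```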


In [6]:
def first_finite_differences(f, x):
    eps = 1e-3
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]


In [7]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256
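Nested `grad` calls work on scalar inputs; for a vector input, `jax.hessian` gives the full matrix of second derivatives. A sketch: since each term of `sum_logistic` depends on a single element, the Hessian should be diagonal with entries σ(x)(1 − σ(x))(1 − 2σ(x)):

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
H = jax.hessian(sum_logistic)(x_small)  # 3x3 matrix of second derivatives

# Analytic second derivative of the sigmoid: s (1 - s) (1 - 2 s).
s = jax.nn.sigmoid(x_small)
print(H)
print(s * (1 - s) * (1 - 2 * s))
```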


In [8]:
key1, key2 = random.split(key)  # use a fresh key per draw instead of reusing one
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))


def apply_matrix(v):
    return jnp.dot(mat, v)


In [9]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.43 ms ± 85.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
@jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
15.1 µs ± 351 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [11]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
34.3 µs ± 1.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
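`vmap` also lets you choose which arguments are batched and where the batch axis goes in the output. In this sketch `apply_matrix` is a variant that takes the matrix explicitly: `in_axes=(None, 0)` broadcasts the matrix and maps over rows of `batched_x`, and `out_axes=1` stacks the results as columns:

```python
import jax
import jax.numpy as jnp
from jax import random

key1, key2 = random.split(random.PRNGKey(0))
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

# Variant of apply_matrix that takes the matrix as an argument.
def apply_matrix(m, v):
    return jnp.dot(m, v)

cols = jax.vmap(apply_matrix, in_axes=(None, 0), out_axes=1)(mat, batched_x)
print(cols.shape)  # (150, 10)
```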


In [12]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
# jnp.convolve(x, y)  # high-level equivalent; promotes dtypes automatically

result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]

Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
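A quick check that the high-level and low-level versions agree, as a sketch. Note that `lax.conv_general_dilated` computes a cross-correlation (the kernel is not flipped); it matches `jnp.convolve` here only because `y` is symmetric:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 1])
y = jnp.ones(10)

# High level: jnp.convolve promotes int32 -> float32 automatically; mode='full' by default.
high = jnp.convolve(x, y)

# Low level: explicit dtype and (batch, channel, width) layout.
low = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(jnp.float32),
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)],
)[0, 0]

print(high)
print(jnp.allclose(high, low))
```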

In [13]:
def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jit(norm)
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

644 µs ± 47.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
547 µs ± 57.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
def get_negatives(x):
    return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)

Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

In [15]:
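# Error: under jit, output shapes must be known at compile time, but x[x < 0]
# has a data-dependent shape (raises NonConcreteBooleanIndexError)
# jit(get_negatives)(x)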
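Two common ways to keep the shape static so the function can be jitted, sketched below with hypothetical function names: mask with `jnp.where` instead of filtering, or ask `jnp.nonzero` for a fixed number of indices via its `size` argument, padding unused slots with `fill_value`:

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array(np.random.randn(10))

# Option 1: keep the shape fixed and zero out the non-negative entries.
@jax.jit
def mask_negatives(x):
    return jnp.where(x < 0, x, 0.0)

# Option 2: request a fixed number of indices; unused slots are filled with -1.
@jax.jit
def negative_indices(x):
    idx, = jnp.nonzero(x < 0, size=x.shape[0], fill_value=-1)
    return idx

print(mask_negatives(x))
print(negative_indices(x))
```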

In [16]:
@jit
def f(x, y):
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result


x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

In [17]:
from jax import make_jaxpr

def f(x, y):
    return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c d
  in (e,) }

In [18]:
from functools import partial

@partial(jit, static_argnums=(1, ))  # mark argument 1 (neg) as static
def f(x, neg):
    return -x if neg else x

f(1, True), f(1, False)

(Array(-1, dtype=int32, weak_type=True), Array(1, dtype=int32, weak_type=True))
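Static arguments are part of the compilation cache key. A sketch that makes this visible by recording a Python side effect, which only runs while tracing (`trace_count` is an illustrative name):

```python
from functools import partial
import jax

trace_count = []

@partial(jax.jit, static_argnums=(1,))
def f(x, neg):
    trace_count.append(neg)  # Python side effect: runs only while tracing
    return -x if neg else x

f(1, True)
f(2, True)   # same static value and input type: cached, no retrace
f(1, False)  # new static value: traced and compiled again
print(trace_count)  # [True, False]
```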

In [19]:
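# Fails: jnp.array(x.shape).prod() is traced, so its value is not known at compile time
# @jit
# def f(x):
#     return x.reshape(jnp.array(x.shape).prod())


# x = jnp.ones((2, 3))
# f(x)

# Works: np.prod operates on the static shape tuple, yielding a concrete integer
@jit
def f(x):
    return x.reshape((np.prod(x.shape), ))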


f(x)

Array([ 0.24124517, -1.2571421 , -0.48511598, -0.9863928 ,  1.3978302 ,
        0.48784977,  1.9099641 , -0.26037157, -0.49505737,  1.3445066 ,
        0.5942803 ,  0.61083764], dtype=float32)

In [20]:
def impure_print_side_effect(x):
    print("Executing function")  # This is a side-effect
    return x


# The side-effects appear during the first run
print("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with arguments of the same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print("Third call, different type: ",
      jit(impure_print_side_effect)(jnp.array([5.])))


Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [21]:
g = 0.


def impure_uses_globals(x):
    return x + g


# JAX captures the value of the global during the first run
print("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the global
print("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print("Third call, different type: ",
      jit(impure_uses_globals)(jnp.array([4.])))


First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
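The usual fix, sketched here with a hypothetical `pure_add`, is to pass the value in as an argument. It then becomes a traced input rather than a constant baked into the compiled function, so updates are always seen:

```python
import jax

@jax.jit
def pure_add(x, g):
    return x + g

g = 0.
print(pure_add(4., g))  # 4.0
g = 10.
print(pure_add(5., g))  # 15.0: the new value is used without recompiling
```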


In [22]:
g = 0.


def impure_saves_global(x):
    global g
    g = x
    return x


# JAX runs the transformed function once, with special Traced values as arguments
print("First call: ", jit(impure_saves_global)(4.))
print("Saved global: ", g)  # The saved global is a leaked Tracer, not a concrete value


First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [23]:
def pure_uses_internal_state(x):
    state = dict(even=0, odd=0)
    for i in range(10):
        state['even' if i % 2 == 0 else 'odd'] += x
    return state['even'] + state['odd']


print(jit(pure_uses_internal_state)(5.))

50.0


In [30]:
@jit
def f(x, y):
    a = x * y
    b = (x + y) / (x - y)
    c = a + 2
    return a + b * c


x = jnp.array([2., 0.])
y = jnp.array([3., 0.])
f(x, y)

Array([-34.,  nan], dtype=float32)
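To find where a NaN like the one above originates, JAX provides the `jax_debug_nans` flag. As a sketch: when a jitted output contains a NaN, JAX re-runs the function op by op and raises `FloatingPointError` at the first primitive that produced it (here the division `0 / 0`). The flag is reset afterwards because it slows execution:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
    a = x * y
    b = (x + y) / (x - y)
    c = a + 2
    return a + b * c

x = jnp.array([2., 0.])
y = jnp.array([3., 0.])

nan_detected = False
jax.config.update("jax_debug_nans", True)
try:
    f(x, y)
except FloatingPointError as err:
    nan_detected = True
    print("FloatingPointError:", err)
finally:
    jax.config.update("jax_debug_nans", False)  # restore the default
```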

In [31]:
# 64-bit types are disabled by default, so float64 is truncated to float32 (with a UserWarning)
x = random.uniform(random.PRNGKey(0), (1000, ), dtype=jnp.float64)
x.dtype

dtype('float32')
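To actually get `float64`, 64-bit mode must be enabled. This is ideally done at program startup, before any arrays are created, either with the `JAX_ENABLE_X64` environment variable or via the config flag as in this sketch:

```python
import jax

# Enable 64-bit types; best done at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
print(x.dtype)  # float64
```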