In [1]:
import numpy as np

import jax.numpy as jnp
from jax import jit

# JIT compilation with `jax.jit`

In [2]:
def norm(X):
    """Standardize ``X`` column-wise.

    Subtracts each column's mean (reduction over axis 0), then divides by
    the standard deviation of the centered column (``.std`` with its default
    ``ddof``). This version runs eagerly, op by op; the cells below wrap the
    same computation with ``jit``. A column with zero spread divides by zero
    and yields nan/inf.
    """
    centered = X - X.mean(axis=0)
    spread = centered.std(axis=0)
    return centered / spread

In [3]:
# Function-call form of jax.jit: wraps `norm` so it is traced and compiled
# on first use. `norm` itself is left unchanged and still runs eagerly.
norm_compiled = jit(norm)

In [4]:
@jit
def norm_compiled_2(X):
    """Decorator form of ``jit(norm)``: identical column-wise standardization.

    Centers each column on its mean (axis 0), then divides by the centered
    column's standard deviation. JAX traces and compiles the body on the
    first call for a given input shape/dtype.
    """
    col_mean = X.mean(axis=0)
    centered = X - col_mean
    return centered / centered.std(axis=0)

In [5]:
# Benchmark input: 10,000 x 10 uniform samples from a seeded generator so the
# notebook is reproducible on Restart & Run All (values don't affect timing).
# The float64 NumPy array becomes float32 in JAX unless x64 mode is enabled.
X = jnp.array(np.random.default_rng(0).random((10000, 10)))

In [6]:
# Compare eager (op-by-op) `norm` against the two jit-compiled variants.
# JAX dispatches work asynchronously; block_until_ready() waits for the
# result so each timing covers the computation, not just the dispatch.
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
%timeit norm_compiled_2(X).block_until_ready()

# Observed on CPU (outputs below): all three land in the same ~330-360 µs
# range; the jitted versions are only marginally faster for this workload.
# NOTE(review): the first jitted call also compiles; %timeit's calibration
# presumably absorbs that, but an explicit warm-up call would make it certain.

356 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
328 µs ± 2.89 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
331 µs ± 3.69 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Tracing

In [7]:
@jit
def f(x, y):
    """Return ``jnp.dot(x + 1, y + 1)`` while printing what JAX sees.

    The ``print`` calls are plain Python side effects. They run only while
    ``jit`` traces the function, so they show abstract tracers (shape and
    dtype, no values). A later call with matching input shapes reuses the
    cached compilation and prints nothing.
    """
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    x_shifted, y_shifted = x + 1, y + 1
    result = jnp.dot(x_shifted, y_shifted)
    print(f"  result = {result}")
    return result

In [8]:
# Seeded per cell so the displayed result is reproducible on a fresh kernel.
rng = np.random.default_rng(1)
x = rng.standard_normal((3, 4))
y = rng.standard_normal(4)
# First call with these shapes: JAX traces f, so its print statements fire.
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([-0.19564901,  0.43656248, -1.0710025 ], dtype=float32)

In [9]:
# Same shapes/dtypes as the previous call, so the cached compilation is
# reused: nothing is printed, only the numeric result changes.
rng = np.random.default_rng(2)
x2 = rng.standard_normal((3, 4))
y2 = rng.standard_normal(4)
f(x2, y2)

Array([ 3.9986274 ,  7.8108335 , -0.76711214], dtype=float32)

In [10]:
# A new shape for x, (2, 4), forces a fresh trace, so "Running f():" prints
# again. Uses its own y3 rather than clobbering the previous cell's y2.
rng = np.random.default_rng(3)
x3 = rng.standard_normal((2, 4))
y3 = rng.standard_normal(4)
f(x3, y3)

Running f():
  x = Traced<ShapedArray(float32[2,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>


Array([4.1383476, 5.5366507], dtype=float32)

# Control Flow

In [11]:
def f(x, n):
    """Sum the first ``n`` entries of ``x`` with an explicit Python loop.

    ``n`` is marked static below, so JAX sees a concrete Python int while
    tracing and unrolls the loop into ``n`` additions. Each distinct ``n``
    triggers its own trace/compile.

    NOTE: this rebinds the name ``f``, shadowing the tracing-demo ``f``
    defined earlier in the notebook.

    Raises:
        ValueError: if ``n`` exceeds the length of ``x``. JAX clamps
            out-of-bounds indices on reads instead of raising, so without
            this check ``x[i]`` would silently re-add the last element.
    """
    # Both n (static) and x.shape[0] are concrete at trace time, so this is
    # an ordinary Python check, not a traced comparison.
    if n > x.shape[0]:
        raise ValueError(f"n={n} exceeds len(x)={x.shape[0]}")
    y = 0.
    for i in range(n):
        y = y + x[i]
    return y

# Mark n (argument index 1) as static: the loop is unrolled at trace time.
f = jit(f, static_argnums=(1,))

In [12]:
# n=2 is static, so this compiles a version with two unrolled additions: 2. + 3. = 5.
f(jnp.array([2., 3., 4.]), 2)

Array(5., dtype=float32)