# Quick start

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax
import numpy as np

In [2]:
# JAX has no hidden global RNG state: every draw takes an explicit PRNG key.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

3.83 ms ± 289 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

12 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x) # put nd array to gpu
%timeit jnp.dot(x, x.T).block_until_ready()

3.42 ms ± 24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## jit

In [6]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Returns lmbda * x where x > 0 and lmbda * (alpha * exp(x) - alpha) elsewhere.
  """
  negative_part = alpha * jnp.exp(x) - alpha
  return lmbda * jnp.where(x > 0, x, negative_part)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

6.16 ms ± 3.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
selu_jit = jit(selu) # (@jit decoration) compile multiple operation together with XLA
%timeit selu_jit(x).block_until_ready()

58.6 µs ± 512 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## grad

In [3]:
def sum_logistic(x):
  """Sum of the logistic sigmoid 1 / (1 + exp(-x)) over all elements of x."""
  sigmoid = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(sigmoid)

x_small = jnp.arange(3.)
# grad returns a new function that evaluates the gradient of sum_logistic.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [4]:
def first_finite_differences(f, x, eps=1e-3):
  """Estimate the gradient of f at x with central finite differences.

  Args:
    f: function of a 1-D array, expected to return a scalar.
    x: 1-D array at which to estimate the gradient.
    eps: finite-difference step. Larger values reduce float32 round-off error
      but increase truncation error. Defaults to the previously hard-coded 1e-3.

  Returns:
    Array with one entry per element of x:
    (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps) for each unit vector e_i.
  """
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                    for v in jnp.eye(len(x))])


# Numerical check: should match grad(sum_logistic) to about 4 decimal places.
print(first_finite_differences(f=sum_logistic, x=x_small))

[0.24998187 0.1965761  0.10502338]


In [5]:
# Third derivative of sum_logistic at 1.0: grad composes with itself, and jit
# can be interleaved anywhere in the chain.
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256


In [None]:
from jax import jacfwd, jacrev
def hessian(fun):
  """Return a jitted function that computes the Hessian of `fun`.

  Uses forward-over-reverse differentiation: jacrev builds the gradient
  and jacfwd differentiates that gradient once more.
  """
  forward_over_reverse = jacfwd(jacrev(fun))
  return jit(forward_over_reverse)

## vmap

In [8]:
# Never reuse a PRNG key for two independent draws (see the JAX PRNG section
# below): split it and give each array its own subkey.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
  """Multiply the global (150, 100) matrix `mat` by a single length-100 vector v."""
  return jnp.dot(mat, v)

In [9]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.06 ms ± 248 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
The slowest run took 4.13 times longer than the fastest. This could mean that an intermediate result is being cached.
261 µs ± 154 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Manually batched
57.8 µs ± 9.76 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
107 µs ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Think in JAX

In [13]:
# JAX: immutable arrays. NumPy-style item assignment raises a TypeError; it is
# caught here so the notebook still runs top to bottom.
x = jnp.arange(10)
try:
  x[0] = 10
except TypeError as e:
  print("Exception {}".format(e))

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [14]:
# x.at[idx].set(value) is the functional replacement for x[idx] = value:
# it returns an updated copy and leaves x unchanged.
y = x.at[0].set(10)
print(x)  # original, unchanged
print(y)  # updated copy

[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]


JAX > LAX > XLA  
">" means "is built on top of": each layer is higher-level and less strict than the one below it.

In [6]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
print(jnp.convolve(x, y))

# The same "full" convolution through the lower-level lax API. It needs
# explicit (batch, channel, width) dims, matching dtypes and explicit padding.
from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),    # lax does no implicit int -> float promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)],  # same as mode='full' in NumPy
)
print(result[0, 0])

[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]
[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]


## jit: static values and traced values

In [9]:
def norm(X):
  """Standardize each column of X to zero mean and unit standard deviation."""
  centered = X - X.mean(0)
  return centered / centered.std(0)

# jit traces norm once per input shape/dtype and reuses that compilation.
norm_compiled = jit(norm)

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

True

In [10]:
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

369 µs ± 9.64 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
85.9 µs ± 6.14 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
def get_negatives(x):
  """Return the negative entries of x via boolean-mask indexing."""
  # The output length depends on the values of x, not just its shape, so this
  # works op-by-op but cannot be traced under jit.
  is_negative = x < 0
  return x[is_negative]

x = jnp.array(np.random.randn(10))
get_negatives(x)  # op-by-op mode
# jit(get_negatives)(x)  # jit mode: fails, the output shape is not static

DeviceArray([-0.16529104, -0.5616348 , -0.02158209, -0.37602973,
             -0.82070136], dtype=float32)

In [19]:
@jit
def f(x, y):
  """Return dot(x + 1, y + 1), printing the values seen while tracing."""
  # Python prints only run during tracing, so they show tracer objects.
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  shifted_x = x + 1
  shifted_y = y + 1
  result = jnp.dot(shifted_x, shifted_y)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([ 5.291437 ,  2.4938517, 10.846224 ], dtype=float32)

In [20]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
# Same shapes and dtypes as before: the cached compilation is reused, so
# nothing from f's body is printed.
print(f(x2, y2))
# y now has shape (4, 1): f is traced and compiled again.
print(f(x2, np.random.randn(4, 1)))

[6.805405  6.7191615 3.2264812]
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4,1])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3,1])>with<DynamicJaxprTrace(level=0/1)>
[[2.3823314]
 [1.3786898]
 [1.6461018]]


In [21]:
from jax import make_jaxpr

def f(x, y):
  """Un-jitted dot(x + 1, y + 1), used to inspect its jaxpr."""
  return jnp.dot(x + 1, y + 1)

# make_jaxpr shows the sequence of primitives JAX records while tracing f.
make_jaxpr(f)(x, y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(((1,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] c d
  in (e,) }

In [22]:
@jit
def f(x, neg):
  """Return -x if neg else x. This fails under jit, as demonstrated below."""
  # Python control flow on a traced argument needs its concrete value,
  # which jit does not have at trace time.
  return -x if neg else x

# Catch the expected error so the notebook keeps running top to bottom.
try:
  f(1, True)
except jax.errors.ConcretizationTypeError as e:
  print("Exception {}".format(e))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_1479/2422663986.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'neg'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [23]:
from functools import partial

@partial(jit, static_argnums=(1,))  # treat `neg` as a compile-time constant
def f(x, neg):
  """Return -x if neg else x; compiled separately for each value of `neg`."""
  print(neg)  # runs at trace time only, once per distinct static value
  if neg:
    return -x
  return x

f(1, True)

True


DeviceArray(-1, dtype=int32, weak_type=True)

In [25]:
f(2, True)   # neg=True was already compiled: no retrace, nothing printed
f(2, False)  # new static value for `neg`: retrace and recompile (prints False)

False


DeviceArray(2, dtype=int32, weak_type=True)

**Key Concepts:**
* Just as values can be either static or traced, operations can be static or traced.
* Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
* Use `numpy` for operations that you want to be static; use `jax.numpy` for operations that you want to be traced.

In [2]:
@jit
def f(x):
  """Print what is and is not static about x while it is being traced."""
  print(f"x = {x}")              # a tracer
  print(f"x.shape = {x.shape}")  # a concrete tuple: shapes are static
  # jnp operations on the static shape produce a traced value again.
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # Uncommenting this raises an error, because reshape needs a static size:
  # return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=0/1)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>


In [3]:
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  """Flatten x to 1-D, computing the target size statically with NumPy."""
  # np.prod on the static shape tuple gives a concrete int at trace time;
  # jnp.prod would give a tracer, which reshape cannot accept.
  flat_size = np.prod(x.shape)
  return x.reshape((flat_size,))

# Flatten the (2, 3) array of ones defined earlier into shape (6,).
f(x)

DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)

In [13]:
def f(x, y):
    # print(x)
    # print(y)
    return jnp.dot(x,y)
jit_f = jit(f)
size = 101
x = random.normal(key, (size,size))
y = random.normal(key, (size,size))
# print(f(x,y))
# print(jit_f(x,y))
%timeit f(x,y)
%timeit jit_f(x,y)

The slowest run took 5.67 times longer than the fastest. This could mean that an intermediate result is being cached.
54.5 µs ± 41 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
74.3 µs ± 8.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
# Jaxpr of the un-jitted dot-product f from the previous cell, traced with its x and y.
# NOTE(review): the NameError recorded below comes from running this cell
# (In [12]) before the cell that defines y (In [13]); it is stale output.
jax.make_jaxpr(f)(x,y)

NameError: name 'y' is not defined

# JAX constraints

In [5]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
# Matplotlib display defaults used by this section.
rcParams.update({
    'image.interpolation': 'nearest',
    'image.cmap': 'viridis',
    'axes.grid': False,
})

### Pure functions
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. **A pure function will always return the same result if invoked with the same inputs.**

In [6]:
def impure_print_side_effect(x):
  """Return x unchanged; the print is a Python side effect."""
  print("Executing function")  # This is a side-effect
  return x

# Python side effects run while JAX traces the function, i.e. on the first call.
print("First call: ", jit(impure_print_side_effect)(4.))

# Same argument type and shape: the cached compilation is reused, so the
# Python body (and its print) may be skipped.
print("Second call: ", jit(impure_print_side_effect)(5.))

# A new argument type/shape forces a retrace, which runs the Python body again.
print("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [7]:
g = 0.
def impure_uses_globals(x):
  """Return x + g, reading g from module globals (hence impure)."""
  return x + g

# The value of g is baked into the compiled function when it is first traced.
print("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Cached compilation: the old g (0.) may still be used silently.
print("Second call: ", jit(impure_uses_globals)(5.))

# A new argument type/shape triggers a retrace, which reads the current g.
print("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


In [12]:
g = 0.
def impure_saves_global(x):
  """Return x and, impurely, store it in the module-level global g."""
  global g
  g = x
  return x

# jit runs the Python function once with abstract Traced values as arguments,
# so the global ends up holding a leaked tracer rather than a number.
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # a leaked tracer
g = 10.  # overwrite the leaked tracer with a concrete value
# Same argument type and shape: JAX reuses the cached compilation, the Python
# body does not run again, and g is therefore NOT updated.
print ("Second call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # still the 10.0 assigned above

First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Second call:  4.0
Saved global:  10.0


In [13]:
def pure_uses_internal_state(x):
  """Add x ten times via a local dict; pure, because the state never escapes."""
  state = {'even': 0, 'odd': 0}
  for step in range(10):
    bucket = 'odd' if step % 2 else 'even'
    state[bucket] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))

50.0


Incorrect attempts to use iterators with JAX:

In [17]:
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop: indexing into an array works as expected.
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, acc: acc + array[i], 0))  # expected result 45
# A Python iterator is advanced only while the loop body is traced, so the
# compiled loop keeps adding the same first element (0).
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i, acc: acc + next(iterator), 0))  # unexpected result 0

# lax.scan: works on arrays; passing an iterator instead raises an error.
def func11(arr, extra):
    """Scan over (arr, ones): carry accumulates arr[t] * 1 + extra, outputs the old carry."""
    ones = jnp.ones(arr.shape)
    def body(carry, elems):
        a, b = elems
        return (carry + a * b + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond: works on arrays; an iterator operand raises an error.
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error

45
0


### In-Place Updates

In [18]:
numpy_array = np.zeros((3, 3), dtype=np.float32)
print("original array:")
print(numpy_array)

# NumPy arrays are mutable: this overwrites row 1 in place.
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


In [19]:
jax_array = jnp.zeros((3, 3), dtype=jnp.float32)

# Unlike NumPy, JAX arrays are immutable: in-place item assignment raises.
try:
  jax_array[1, :] = 1.0
except Exception as e:
  print("Exception {}".format(e))

Exception '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [21]:
# .at[...].set(...) returns a new array; jax_array itself is left untouched.
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
print("original array unchanged:\n", jax_array)

updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [22]:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

# Other .at[] methods accept any NumPy-style index. Here 7 is added to
# columns 3 onward of every second row, producing a new array.
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)

original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]


### Out-of-bounds indexing
JAX does not raise an error for out-of-bounds indices:
- updates at an out-of-bounds index are skipped;
- retrievals at an out-of-bounds index are clamped to the nearest valid index.

In [23]:
# numpy
# NumPy raises an IndexError for an out-of-bounds index.
try:
  np.arange(10)[11]
except Exception as e:
  print("Exception {}".format(e))

Exception index 11 is out of bounds for axis 0 with size 10


In [24]:
# No error in JAX: the out-of-bounds index is clamped to the last valid
# position, so this returns element 9.
jnp.arange(10)[11]

DeviceArray(9, dtype=int32)

### Non-array inputs
Always pass (j)np arrays as function arguments, not Python lists.

In [25]:
def permissive_sum(x):
  """Sum x after converting it with jnp.array (also accepts Python lists)."""
  # Works, but is a bad pattern: under jit each list element would be traced
  # as a separate argument.
  as_array = jnp.array(x)
  return jnp.sum(as_array)

x = list(range(10))
permissive_sum(x)

DeviceArray(45, dtype=int32)

In [28]:
# Do this instead: convert the list to an array once, then operate on the array.
jnp.sum(jnp.array(x))

DeviceArray(45, dtype=int32)

### Random numbers

In [None]:
# NumPy's global PRNG is stateful: consecutive calls give different values.
print(np.random.random())
print(np.random.random())
print(np.random.random())

# Seed, then inspect the Mersenne Twister state: (name, key words, position, ...).
np.random.seed(0)
rng_state = np.random.get_state()
# rng_state -> ('MT19937', array([0, 1, 1812433255, ...], dtype=uint32), 624, 0, 0.0)

# Position 624 means "exhausted": the first draw regenerates the state vector
# and consumes two 32-bit words, leaving the position at 2.
_ = np.random.uniform()
rng_state = np.random.get_state()
# rng_state -> ('MT19937', array([2443250962, ..., 1678096082], dtype=uint32), 2, 0, 0.0)

# 311 further draws consume the remaining 622 words (position back to 624).
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
# rng_state -> ('MT19937', array([2443250962, ..., 1678096082], dtype=uint32), 624, 0, 0.0)

# The next draw regenerates the whole state vector again.
_ = np.random.uniform()
rng_state = np.random.get_state()
# rng_state -> ('MT19937', array([1499117434, 2949980591, ...], dtype=uint32), 2, 0, 0.0)

JAX *explicit* PRNG

In [30]:
from jax import random
# A JAX PRNG key is a small uint32 array derived from the seed.
key = random.PRNGKey(seed=0)
key

DeviceArray([0, 0], dtype=uint32)

In [31]:
# Random functions never modify the key they receive, so reusing a key
# reproduces exactly the same numbers.
print(random.normal(key, shape=(1,)))
print(key)
# No no no! Same key, same "random" value.
print(random.normal(key, shape=(1,)))
print(key)

[-0.20584226]
[0 0]
[-0.20584226]
[0 0]


In [32]:
print("old key", key)
# Split: keep `key` for future splits and use `subkey` exactly once.
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
# Raw strings: "\-" is an invalid escape sequence (SyntaxWarning on Python 3.12+);
# the printed text is unchanged.
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)


old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]


In [33]:
# One split call yields a new carry key plus three single-use subkeys.
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))

[1.4544677]
[-0.49327764]
[-0.39947626]


### Control Flow

In [None]:
# grad works on ordinary Python functions, including Python control flow:
# the branch is chosen from the concrete input value.
def f(x):
  """Piecewise function: 3x^2 for x < 3, otherwise -4x."""
  return 3. * x ** 2 if x < 3 else -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!

In [34]:
@jit
def f_loop(x):  # OK
  """Double x three times; the fixed-length Python loop is unrolled at trace time."""
  for i in range(3):
    x = 2 * x
  return x

@jit
def g(x):  # OK
  """Sum the entries of x; the loop length comes from the static shape."""
  y = 0.
  for i in range(x.shape[0]):  # unrolled at trace time into traced additions
    y = y + x[i]
  return y

@jit
def f(x):  # BAD
  """Piecewise 3x^2 / -4x: fails under jit, Python `if` on a traced value."""
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
try:
  f(2)
except Exception as e:
  print("Exception {}".format(e))

Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_914/2039591909.py:14 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


By default, `jit` traces your code at the `ShapedArray` abstraction level, where each abstract value represents the set of all arrays with a fixed shape and dtype.

In [36]:
def f(x):  # fine once x is static
  """Piecewise 3x^2 for x < 3, otherwise -4x, using Python control flow."""
  if x < 3:
    return 3. * x ** 2
  return -4 * x

# Marking x static gives jit its concrete value (recompiling for each new x).
f = jit(f, static_argnums=(0,))

print(f(2.))

12.0


In [38]:
def f(x, n):
  """Sum the first n entries of x with a Python loop (n must be static under jit)."""
  total = 0.
  for idx in range(n):  # n is static, so the loop is unrolled at trace time
    total = total + x[idx]
  return total

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2)

DeviceArray(5., dtype=float32)

In [None]:
def example_fun(length, val):
  """Return a 1-D array of `length` elements, all equal to `val`."""
  shape = (length,)
  return jnp.ones(shape) * val
# un-jit'd works fine
print(example_fun(5, 4))

# Under plain jit, `length` is traced, but array shapes must be concrete.
bad_example_jit = jit(example_fun)
# this will fail:
try:
  print(bad_example_jit(10, 4))
except Exception as e:
  print("Exception {}".format(e))
# static_argnums tells JAX to recompile whenever the argument at these positions changes:
good_example_jit = jit(example_fun, static_argnums=(0,))
print(good_example_jit(10, 4))  # first compile
print(good_example_jit(5, 4))   # new length: recompiles

In [39]:
@jit
def f(x):
  """Return 2 * x, printing the traced input and output during tracing."""
  print(x)  # a tracer, not the concrete argument
  doubled = 2 * x
  print(doubled)
  return doubled
f(2)

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


DeviceArray(4, dtype=int32, weak_type=True)

`lax.cond` differentiable

`lax.while_loop` fwd-mode-differentiable

`lax.fori_loop` fwd-mode-differentiable in general; fwd and rev-mode differentiable if endpoints are static.

`lax.scan` differentiable

In [40]:
from jax import lax

def cond(pred, true_fun, false_fun, operand):
  """Pure-Python reference semantics of lax.cond."""
  branch = true_fun if pred else false_fun
  return branch(operand)

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)

DeviceArray([-1.], dtype=float32)

In [42]:
def while_loop(cond_fun, body_fun, init_val):
  """Pure-Python reference semantics of lax.while_loop."""
  val = init_val
  while True:
    if not cond_fun(val):
      return val
    val = body_fun(val)

init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x + 1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
# --> array(10, dtype=int32)

DeviceArray(10, dtype=int32, weak_type=True)

In [43]:
def fori_loop(start, stop, body_fun, init_val):
  """Pure-Python reference semantics of lax.fori_loop.

  Calls body_fun(i, val) for i in range(start, stop), threading val through,
  and returns the final val.
  """
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val

# Loop inputs are defined at module level so this cell does not depend on
# values left over from earlier cells.
init_val = 0
start = 0
stop = 10
body_fun = lambda i, x: x + i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
# --> array(45, dtype=int32)

DeviceArray(45, dtype=int32, weak_type=True)