# Quick start

In [8]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

In [2]:
# JAX randomness is explicit: create a PRNG key, then draw a length-10 vector
# of standard-normal samples from it.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

3.83 ms ± 289 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

12 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x) # put nd array to gpu
%timeit jnp.dot(x, x.T).block_until_ready()

3.42 ms ± 24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## jit

In [6]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * (alpha * exp(x) - alpha)``
  elsewhere. The defaults are rounded versions of the SELU constants.
  """
  negative_branch = alpha * jnp.exp(x) - alpha
  return lmbda * jnp.where(x > 0, x, negative_branch)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

6.16 ms ± 3.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
selu_jit = jit(selu) # (@jit decoration) compile multiple operation together with XLA
%timeit selu_jit(x).block_until_ready()

58.6 µs ± 512 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## grad

In [3]:
def sum_logistic(x):
  """Sum of the logistic sigmoid ``1 / (1 + exp(-x))`` over all elements of ``x``."""
  sigmoid = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(sigmoid)

derivative_fn = grad(sum_logistic)  # grad returns a new function: d(sum_logistic)/dx
x_small = jnp.arange(3.)            # evaluation point [0., 1., 2.]
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [4]:
def first_finite_differences(f, x, eps=1e-3):
  """Numerically approximate the gradient of scalar function ``f`` at ``x``.

  Uses central differences along each coordinate axis ``e_i``:
  ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)``.

  Args:
    f: function of ``x`` returning a scalar.
    x: point at which to differentiate; assumed 1-D (one unit vector of
      ``jnp.eye(len(x))`` is added per coordinate).
    eps: step size (default 1e-3, the previously hard-coded value). Smaller
      steps amplify float32 rounding error; larger ones increase truncation error.

  Returns:
    Array of length ``len(x)`` with one partial-derivative estimate per coordinate.
  """
  return jnp.array([(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)
                    for e_i in jnp.eye(len(x))])


# Numerical check of the autodiff gradient computed for the same point above.
fd_estimate = first_finite_differences(sum_logistic, x_small)
print(fd_estimate)

[0.24998187 0.1965761  0.10502338]


In [5]:
# grad and jit each return an ordinary function, so they compose freely:
# build the third derivative of sum_logistic step by step and evaluate it at 1.0.
first_deriv = grad(sum_logistic)
second_deriv = grad(jit(first_deriv))
third_deriv = grad(jit(second_deriv))
print(third_deriv(1.0))

-0.0353256


In [None]:
from jax import jacfwd, jacrev
def hessian(fun):
  """Return a jitted function that computes the Hessian of ``fun``.

  Forward-over-reverse: ``jacrev(fun)`` is the gradient, and ``jacfwd`` of
  that gradient is the matrix of second derivatives (for scalar-valued ``fun``).
  """
  gradient_fn = jacrev(fun)
  return jit(jacfwd(gradient_fn))

## vmap

In [8]:
# Draw `mat` and `batched_x` from independent subkeys. Reusing the same `key`
# for both draws yields non-independent samples, which JAX's PRNG docs warn against.
mat_key, x_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(x_key, (10, 100))

def apply_matrix(v):
  """Multiply a single vector ``v`` (length 100, since ``mat`` is 150x100) by ``mat``."""
  return jnp.dot(mat, v)

In [9]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.06 ms ± 248 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
The slowest run took 4.13 times longer than the fastest. This could mean that an intermediate result is being cached.
261 µs ± 154 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Manually batched
57.8 µs ± 9.76 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
107 µs ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Think in JAX

In [13]:
# JAX: immutable arrays
# JAX: immutable arrays. In-place item assignment raises a TypeError.
# Catch and display it so the notebook still runs top to bottom.
x = jnp.arange(10)
try:
  x[0] = 10
except TypeError as err:
  print(f"TypeError: {err}")

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [14]:
# Functional update: `.at[idx].set(v)` returns a modified copy and leaves `x` intact.
y = x.at[0].set(10)
for arr in (x, y):
  print(arr)

[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]


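JAX > LAX > XLA  
">" means "is built on": each layer is a higher-level, less strict API built on top of the one to its right (e.g. `jax.numpy` implicitly promotes mixed dtypes, while `lax` requires explicit promotion).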

In [6]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
# High-level API: jnp.convolve accepts the mixed int/float inputs directly.
print(jnp.convolve(x, y))

from jax import lax
# The same 'full' result via lax. Inputs need explicit (N, C, W) dims and
# matching dtypes. lax computes a cross-correlation, which is identical here
# because the kernel `y` is symmetric.
result = lax.conv_general_dilated(
    lhs=x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    rhs=y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)],  # equivalent of padding='full' in NumPy
)
print(result[0, 0])

[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]
[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]


## jit: static values & tracers

In [9]:
def norm(X):
  """Standardize X along axis 0 (per column for 2-D X): zero mean, unit std."""
  centered = X - X.mean(0)
  return centered / centered.std(0)

# jit traces and compiles norm once per input shape/dtype.
norm_compiled = jit(norm)

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

True

In [10]:
# Eager (op-by-op) vs. compiled execution of the same function. norm_compiled
# was already called once in the previous cell, so compilation is not timed here.
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

369 µs ± 9.64 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
85.9 µs ± 6.14 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
def get_negatives(x):
  """Return the negative entries of ``x`` using boolean-mask indexing.

  The output length depends on the *values* of ``x``, not only its shape, so
  this works eagerly but cannot be traced by jit, which needs static shapes.
  """
  return x[x < 0]

x = jnp.array(np.random.randn(10))

# jit mode: the mask is a tracer, so the result shape is unknown at trace time
# and JAX raises an IndexError subclass. Show it instead of leaving dead code.
try:
  jit(get_negatives)(x)
except IndexError as err:
  print(f"{type(err).__name__}: {err}")

get_negatives(x)  # op-by-op mode works

DeviceArray([-0.16529104, -0.5616348 , -0.02158209, -0.37602973,
             -0.82070136], dtype=float32)

In [19]:
@jit
def f(x, y):
  # These prints execute only while JAX traces f, so they show abstract
  # tracer objects (shape/dtype), not concrete values.
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  shifted_x, shifted_y = x + 1, y + 1
  result = jnp.dot(shifted_x, shifted_y)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([ 5.291437 ,  2.4938517, 10.846224 ], dtype=float32)

In [20]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)

# Same shapes/dtypes as the first call: the cached compilation is reused,
# so the trace-time prints inside f do not run.
print(f(x2, y2))

# y now has shape (4, 1): new abstract signature, so f is re-traced and recompiled.
y_col = np.random.randn(4, 1)
print(f(x2, y_col))

[6.805405  6.7191615 3.2264812]
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4,1])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3,1])>with<DynamicJaxprTrace(level=0/1)>
[[2.3823314]
 [1.3786898]
 [1.6461018]]


In [21]:
from jax import make_jaxpr

def f(x, y):
  """Un-jitted version of f, used only to inspect its jaxpr."""
  return jnp.dot(x + 1, y + 1)

# make_jaxpr(f) returns a function that traces f with abstract values shaped
# like its arguments and returns the resulting jaxpr.
jaxpr_of_f = make_jaxpr(f)
jaxpr_of_f(x, y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(((1,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] c d
  in (e,) }

In [22]:
from jax.errors import ConcretizationTypeError

@jit
def f(x, neg):
  # Python `if`/ternary needs a concrete bool, but under jit `neg` is an
  # abstract tracer, so this branch cannot be evaluated at trace time.
  return -x if neg else x # Op flow contain branching that depends on input.

# Expected failure: catch and display it so Run All does not stop here.
try:
  f(1, True)
except ConcretizationTypeError as err:
  print(f"ConcretizationTypeError: {err}")

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_1479/2422663986.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'neg'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [23]:
from functools import partial

@partial(jit, static_argnums=(1,))  # treat `neg` as a static (compile-time) argument
def f(x, neg):
  # `neg` is a concrete Python value here, so plain control flow works.
  # This print runs only when f is (re)traced.
  print(neg)
  if neg:
    return -x
  return x

f(1, True)

True


DeviceArray(-1, dtype=int32, weak_type=True)

In [25]:
f(2, True)   # same static value and argument types as before: cached, no retrace
f(2, False) # re-compile since static arg is changed (hence the printed `False`)

False


DeviceArray(2, dtype=int32, weak_type=True)

**Key Concepts:**
* Just as values can be either static or traced, operations can be static or traced.
* Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
* Use `numpy` for operations that you want to be static; use `jax.numpy` for operations that you want to be traced.

In [26]:
@jit
def f(x):
  # Under jit, x is a tracer. Its .shape is a plain static tuple, but any value
  # computed with jax.numpy (e.g. the product below) is itself a tracer.
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # Uncomment the next line to reproduce the error: reshape needs a concrete
  # size, but jnp.array(...).prod() is a tracer.
  # return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=0/1)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>


In [27]:
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  """Flatten ``x`` to 1-D using a size computed with NumPy at trace time.

  ``np.prod`` on the static ``x.shape`` tuple gives a concrete integer, whereas
  ``jnp.prod`` would give a tracer that reshape rejects. ``dtype=int`` keeps the
  result integral even for 0-d ``x`` (``np.prod(())`` is otherwise the float 1.0).
  """
  return x.reshape((np.prod(x.shape, dtype=int),)) # reshape require static so use numpy!

f(x)

DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)

In [37]:
def f(x, y):
    # print(x)
    # print(y)
    return jnp.dot(x,y)
jit_f = jit(f)
size = 101
x = random.normal(key, (size,size))
y = random.normal(key, (size,size))
# print(f(x,y))
# print(jit_f(x,y))
%timeit f(x,y)
%timeit jit_f(x,y)

39.2 µs ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
52.2 µs ± 5.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
