In [11]:
# Unlike NumPy arrays, which are mutable, JAX arrays are always immutable.

In [3]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [4]:
import jax
import jax.numpy as jnp

In [9]:
type(jnp.linspace(0, 100, 51))

jaxlib.xla_extension.ArrayImpl

In [10]:
# Python's duck typing allows JAX and NumPy arrays to be used interchangeably in many places.
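A small sketch of the duck typing described above. The helper `mean_square` is a hypothetical name introduced only for illustration; the point is that one function body can accept either array type, and that each library's functions generally accept the other's arrays.

```python
import numpy as np
import jax.numpy as jnp


def mean_square(x):
    # Relies only on operators and methods that both array types provide.
    return (x ** 2).mean()


x_np = np.linspace(0.0, 1.0, 5)
x_jax = jnp.linspace(0.0, 1.0, 5)

ms_np = mean_square(x_np)    # computed by NumPy
ms_jax = mean_square(x_jax)  # computed by JAX

# Mixing also works in many places: NumPy functions accept JAX arrays
# and jax.numpy functions accept NumPy arrays.
total_via_numpy = np.sum(x_jax)
sine_via_jax = jnp.sin(x_np)
```

The results agree up to floating-point precision, since JAX defaults to 32-bit floats while NumPy defaults to 64-bit.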

In [13]:
# Jax array update
x = jnp.array([2,3,4])
y = x.at[2].set(5)

In [14]:
x, y

(Array([2, 3, 4], dtype=int32), Array([2, 3, 5], dtype=int32))
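For contrast with the `.at[...].set(...)` update above, here is a minimal sketch of what happens with ordinary index assignment. The variable names are illustrative only: NumPy mutates in place, while JAX rejects the assignment with a `TypeError` and leaves the array unchanged.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([2, 3, 4])
a[2] = 5  # NumPy: in-place mutation succeeds

b = jnp.array([2, 3, 4])
try:
    b[2] = 5  # JAX: item assignment is not supported
    mutated = True
except TypeError:
    mutated = False

# The functional update returns a new array and leaves b untouched.
b_updated = b.at[2].set(5)
```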

In [15]:
"""
jax.numpy is a high-level API with a familiar interface.
jax.lax is a lower-level API that is stricter and often more powerful.
"""

'\njax.numpy is a high-level API with a familiar interface.\njax.lax is a lower-level API that is stricter and often more powerful.\n'

In [16]:
"""
All jax operations are implemented in terms of operations in XLA
"""

'\nAll jax operations are implemented in terms of operations in XLA\n'
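To get a feel for this, a jitted function can be lowered without running it and the resulting compiler IR inspected. This is only a sketch: `scaled_sine` is a hypothetical function, and the exact text printed depends on the installed JAX version.

```python
import jax
import jax.numpy as jnp


def scaled_sine(x):
    return 2.0 * jnp.sin(x)


# Lower the function for a concrete input shape/dtype, without executing it.
lowered = jax.jit(scaled_sine).lower(jnp.ones(4))

# Textual form of the program that is handed to XLA.
ir_text = lowered.as_text()
print(ir_text)
```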

In [17]:
"""
By default, jax executes operations one at a time in sequence
"""

'\nBy default, jax executes operations one at a time in sequence\n'

In [19]:
"""
JIT compilation converts a block of code into a single XLA computation that runs all at once
XLA optimizes the code through
* fusion of operations
* avoiding temporary allocations
Use block_until_ready() when timing, since JAX dispatches work asynchronously

Requires all arrays to have static shapes
"""

'\nJIT compilation converts a block of code into a single XLA computation that runs all at once\nXLA optimizes the code through\n* fusion of operations\n* avoiding temporary allocations\nUse block_until_ready() when timing, since JAX dispatches work asynchronously\n\nRequires all arrays to have static shapes\n'
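A rough timing sketch for the points above. The function `selu` and the helper `timed` are illustrative choices, not part of the notebook; the actual speedup, if any, depends on hardware and array size. The warm-up call keeps compilation time out of the measurement, and `block_until_ready()` makes sure the computation has finished before the clock stops.

```python
import time

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    # Several elementwise ops that XLA can fuse into one kernel.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)
x = jnp.linspace(-5.0, 5.0, 1_000_000)

# Warm-up: the first call traces and compiles.
selu_jit(x).block_until_ready()


def timed(fn, arg):
    start = time.perf_counter()
    fn(arg).block_until_ready()  # wait for the asynchronous dispatch to finish
    return time.perf_counter() - start


t_eager = timed(selu, x)
t_jit = timed(selu_jit, x)
print(f"eager: {t_eager:.6f}s  jit: {t_jit:.6f}s")
```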

JIT and other JAX transforms work by tracing a function  
Variables you don't want to be traced can be marked as static

When a jitted function is run for the first time, it is traced and compiled to XLA. When it is called again with the same input shapes, dtypes, and static arguments, the cached compiled version is run directly.
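The caching behaviour can be observed with a Python side effect, which only runs while the function is being traced. In this sketch `double` and `trace_log` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side record, only touched during tracing


@jax.jit
def double(x):
    trace_log.append(x.shape)  # side effect: runs at trace time, not at run time
    return 2 * x


double(jnp.ones(3))  # first call: traced and compiled
double(jnp.ones(3))  # same shape and dtype: cached executable, no re-trace
double(jnp.ones(4))  # new shape: traced and compiled again
print(trace_log)
```

The log ends up with one entry per distinct input shape rather than one per call.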

In [20]:
def f(x, n):
    c = 0
    for i in range(n):
        c += x
        
    return c

In [22]:
jax.jit(f, static_argnums=[1])(2, 3)

Array(6, dtype=int32, weak_type=True)

In [23]:
jax.jit(f, static_argnames=['n'])(2, 3)

Array(6, dtype=int32, weak_type=True)

In [27]:
jax.make_jaxpr(f, static_argnums=[0, 1])(2,3)

{ lambda ; . let  in (6,) }
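When every argument is static, as in the cell above, the whole computation is folded into a constant. Marking only `n` as static keeps `x` traced, and the Python loop is unrolled into the jaxpr. The sketch below uses a renamed copy of the function (`repeat_add`, a hypothetical name) so it does not clobber `f`; the printed jaxprs may look slightly different across JAX versions.

```python
import jax
import jax.numpy as jnp


def repeat_add(x, n):
    c = 0
    for _ in range(n):  # plain Python loop: unrolled while tracing
        c += x
    return c


x = jnp.ones(2)

# n is static, x is traced: each loop iteration shows up as an operation.
jaxpr_n3 = jax.make_jaxpr(repeat_add, static_argnums=1)(x, 3)
jaxpr_n5 = jax.make_jaxpr(repeat_add, static_argnums=1)(x, 5)
print(jaxpr_n3)
print(jaxpr_n5)

result = jax.jit(repeat_add, static_argnums=1)(x, 3)
```

A different value of a static argument produces a different traced program, which is why changing `n` triggers a recompile under `jit`.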

Just like values, you can have traced or static operations.

Static operations are evaluated at compile time in Python, and traced operations are evaluated at run time in XLA.

Use numpy for operations that you want to be static, and jax.numpy for traced.
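A minimal sketch of the difference, using hypothetical names (`scale_by_size`, `seen`). Inside the jitted function, the NumPy result is an ordinary concrete number, while the `jax.numpy` result is a tracer that stands in for a value computed later by XLA.

```python
import numpy as np
import jax
import jax.numpy as jnp

seen = {}  # filled in while the function is traced


@jax.jit
def scale_by_size(x):
    static_size = np.prod(x.shape)              # NumPy on the static shape
    traced_size = jnp.prod(jnp.array(x.shape))  # staged into the XLA program
    seen["static"] = isinstance(static_size, jax.core.Tracer)
    seen["traced"] = isinstance(traced_size, jax.core.Tracer)
    return x * int(static_size)


out = scale_by_size(jnp.ones((2, 3)))
print(seen)
```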

Under jit, the shape argument to a function like reshape must be static.

In [44]:
def f(x):
    # x.shape is a static tuple, so this reshape traces without error
    return x.reshape(x.shape)

In [45]:
jax.jit(f)(jnp.array([1,2,3,4,5,6]))

In [46]:
import numpy as np


In [48]:
@jax.jit
def f(x):
    # np.prod runs in Python on the static shape, so the new shape is concrete
    return x.reshape((np.prod(x.shape),))

f(x)

Array([2, 3, 4], dtype=int32)

Observe that inside jit, jax.numpy operations such as jnp.array(x.shape).prod() produce traced values, while NumPy operations such as np.prod(x.shape) give you static values.
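The two cases can be put side by side. In this sketch (function names are hypothetical), computing the new shape with `jax.numpy` yields a traced value that `reshape` cannot accept, so the call is expected to fail with a `TypeError`, while the NumPy version goes through.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def flatten_traced(x):
    # jnp ops inside jit produce a traced value, which is not a valid shape.
    return x.reshape(jnp.array(x.shape).prod())


@jax.jit
def flatten_static(x):
    # np.prod on the static shape produces a concrete integer.
    return x.reshape((np.prod(x.shape),))


m = jnp.ones((2, 3))

try:
    flatten_traced(m)
    traced_shape_ok = True
except TypeError:
    traced_shape_ok = False

flat = flatten_static(m)
```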