In [1]:
import jax
import jax.numpy as jnp
# Report the installed JAX version and the devices JAX can use.
print("Using jax", jax.__version__)
print("Available Device", jax.devices())

Using jax 0.4.5
Available Device [CpuDevice(id=0)]


# Basic

`jnp` automatically creates arrays on JAX's default device: a GPU (or other accelerator) when one is available, otherwise the CPU.

In [2]:
# Allocate a 2x5 float32 array of zeros; JAX picks the device for us.
a = jnp.zeros((2, 5), dtype=jnp.float32)
print(a)
# The concrete array class comes from jaxlib rather than NumPy.
print(a.__class__)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
<class 'jaxlib.xla_extension.Array'>


In [3]:
# Show which device currently holds `a`.
# NOTE(review): newer JAX releases deprecate the `device()` method in favour
# of `a.devices()` -- confirm against the installed version before upgrading.
a.device()  # GpuDevice if you have GPU

CpuDevice(id=0)

In [4]:
# Fetch the array back to the host; the result is a plain NumPy array.
a_cpu = jax.device_get(a)
print(a_cpu.__class__)

# Put the host copy back onto a JAX device, which yields a JAX array again.
# The commented-out line is the alternative of passing `a` directly.
# a_acc = jax.device_put(a)
a_acc = jax.device_put(a_cpu)
print(a_acc.__class__)

<class 'numpy.ndarray'>
<class 'jaxlib.xla_extension.Array'>


In [5]:
# A NumPy array and a JAX array can be added directly; the sum is a JAX array.
a_cpu + a_acc

Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

In [6]:
# JAX arrays are immutable: `.at[1].set(1)` returns an updated copy with row 1
# set to 1 and leaves `a` untouched, as the printed pair shows.
new_a = a.at[1].set(1)
print(a, '\n\n', new_a)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]] 

 [[0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1.]]


# Jaxpr

In [7]:
@jax.jit
def simple_graph(x):
    """Return the mean of ``(x + 2) ** 2 + 3`` taken over all elements of ``x``.

    The function has no side effects, so it is suitable for ``jax.jit``.

    Args:
        x: Input array.

    Returns:
        A scalar array holding the mean of the transformed elements.
    """
    # Tip: print any intermediate below to see what JAX passes through
    # the function while it is being traced.
    shifted = x + 2
    squared = shifted ** 2
    offset = squared + 3
    return offset.mean()

# Small float32 test vector [0., 1., 2.] used as the example input.
inp = jnp.arange(3, dtype=jnp.float32)
print('Input', inp)
print('Output', simple_graph(inp))

Input [0. 1. 2.]
Output 12.666667


View the jaxpr of this function.

In [8]:
# Trace `simple_graph` on `inp` and display the primitives JAX recorded.
jax.make_jaxpr(simple_graph)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = add c 3.0
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 3.0
  in (f,) }

In [9]:
global_list = []

# Deliberately impure: appending to a module-level list is a side effect.
def norm(x):
    """Return the Euclidean (L2) norm of ``x``.

    Computes the square root of the sum of the squared elements. As a
    deliberate anti-pattern, the argument is also appended to the
    module-level ``global_list``; that mutation is plain Python and does not
    show up in the function's jaxpr.

    Args:
        x: Input array.

    Returns:
        A scalar array holding the norm.
    """
    global_list.append(x)  # the side effect this example is about
    return jnp.sqrt((x ** 2).sum())

# Printed before `norm` has ever been called, so the list is still empty here.
print(global_list)
jax.make_jaxpr(norm)(inp)  # no 'append' in jaxpr

[]


{ lambda ; a:f32[3]. let
    b:f32[3] = integer_pow[y=2] a
    c:f32[] = reduce_sum[axes=(0,)] b
    d:f32[] = sqrt c
  in (d,) }

# Autograd in Jax

In [10]:
# `jax.grad` turns `simple_graph` into a function that returns the gradient of
# its scalar output with respect to the input.
grad_function = jax.grad(simple_graph)
# Evaluate the gradient at `inp`; the result has one entry per input element.
gradients = grad_function(inp)
print('Gradient', gradients)

Gradient [1.3333334 2.        2.6666667]


In [11]:
# The gradient function can be traced as well, so its jaxpr can be inspected.
jax.make_jaxpr(grad_function)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = integer_pow[y=1] b
    e:f32[3] = mul 2.0 d
    f:f32[3] = add c 3.0
    g:f32[] = reduce_sum[axes=(0,)] f
    _:f32[] = div g 3.0
    h:f32[] = div 1.0 3.0
    i:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] h
    j:f32[3] = mul i e
  in (j,) }

In [12]:
# `jax.value_and_grad` wraps `simple_graph` so that one call returns a
# `(value, gradient)` tuple.
val_grad_function = jax.value_and_grad(simple_graph)
val_grad_function(inp)

(Array(12.666667, dtype=float32),
 Array([1.3333334, 2.       , 2.6666667], dtype=float32))