In [1]:
## Standard libraries
import os
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

## Progress bar
from tqdm.auto import tqdm

  set_matplotlib_formats('svg', 'pdf') # For export


In [2]:
import jax
import jax.numpy as jnp
print("Using jax", jax.__version__)

Using jax 0.4.26


In [3]:
a = jnp.zeros((2,5), dtype=jnp.float32)
print(a)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]




In [4]:
b = jnp.arange(6)
print(b)

[0 1 2 3 4 5]


In [5]:
b.__class__

jaxlib.xla_extension.ArrayImpl

In [6]:
b.device

<bound method ArrayImpl.device of Array([0, 1, 2, 3, 4, 5], dtype=int32)>

In [7]:
b_cpu = jax.device_get(b)
print(b_cpu.__class__)

<class 'numpy.ndarray'>


In [8]:
b_gpu = jax.device_put(b_cpu)
print(f'Device put: {b_gpu.__class__} on {b_gpu.devices()}')

Device put: <class 'jaxlib.xla_extension.ArrayImpl'> on {METAL(id=0)}


In [9]:
b_cpu + b_gpu

Array([ 0,  2,  4,  6,  8, 10], dtype=int32)

In [10]:
b_new = b.at[0].set(1)
print('Original array:', b)
print('Changed array:', b_new)

Original array: [0 1 2 3 4 5]
Changed array: [1 1 2 3 4 5]


In [11]:
rng = jax.random.PRNGKey(42)

In [12]:
jax_random_number_1 = jax.random.normal(rng)
jax_random_number_2 = jax.random.normal(rng)
print('JAX - Random number 1:', jax_random_number_1)
print('JAX - Random number 2:', jax_random_number_2)

JAX - Random number 1: -0.9350606
JAX - Random number 2: -0.9350606


In [13]:
np.random.seed(42)
np_random_number_1 = np.random.normal()
np_random_number_2 = np.random.normal()
print('NumPy - Random number 1:', np_random_number_1)
print('NumPy - Random number 2:', np_random_number_2)

NumPy - Random number 1: 0.4967141530112327
NumPy - Random number 2: -0.13826430117118466


In [14]:
rng, subkey1, subkey2 = jax.random.split(rng,num=3) 

In [15]:
#no splitting. I get the same numbers always
jax_random_number_0 = jax.random.normal(rng)
jax_random_number_1 = jax.random.normal(subkey1)
jax_random_number_2 = jax.random.normal(subkey2)
print('JAX new - Random number 0:', jax_random_number_0)
print('JAX new - Random number 1:', jax_random_number_1)
print('JAX new - Random number 2:', jax_random_number_2)

JAX new - Random number 0: -0.5484217
JAX new - Random number 1: -0.18267898
JAX new - Random number 2: 0.7080024


In [16]:
#every time i run this cell i well get different values
rng, subkey1, subkey2 = jax.random.split(rng,num=3) 
jax_random_number_0 = jax.random.normal(rng)
jax_random_number_1 = jax.random.normal(subkey1)
jax_random_number_2 = jax.random.normal(subkey2)
print('JAX new - Random number 0:', jax_random_number_0)
print('JAX new - Random number 1:', jax_random_number_1)
print('JAX new - Random number 2:', jax_random_number_2)

JAX new - Random number 0: -0.26986617
JAX new - Random number 1: -0.4367844
JAX new - Random number 2: -0.082964614


## Tutorial 101

In [17]:
x = jnp.arange(10)
x

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [34]:
long_vector = jnp.arange(int(1e8))
long_vector.__class__

jaxlib.xla_extension.ArrayImpl

In [35]:
%timeit jnp.dot(long_vector, long_vector).block_until_ready()

3.36 ms ± 311 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [36]:
cpu_long_vector = jax.device_get(long_vector)
cpu_long_vector.__class__

numpy.ndarray

In [37]:
%timeit jnp.dot(long_vector, long_vector).block_until_ready()

2.81 ms ± 125 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [38]:
def sum_of_squares(x):
    return jnp.sum(x**2)

In [44]:
sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0,2.0,3.0,4.0])
print(x.__class__)
print(sum_of_squares(x))
print(sum_of_squares_dx(x))

<class 'jaxlib.xla_extension.ArrayImpl'>
30.0
[2. 4. 6. 8.]


In [47]:
def sum_squared_error(x,y):
    return jnp.sum((x-y)**2)
sum_squared_error_dx = jax.grad(sum_squared_error)
y = jnp.asarray([1.1,2.1,3.1,4.1])
print(sum_squared_error(x,y))
print(sum_squared_error_dx(x,y))

0.039999947
[-0.20000005 -0.19999981 -0.19999981 -0.19999981]


In [48]:
jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # Find gradient wrt both x & y

(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 Array([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))

In [51]:
def simple_graph(x):
    x = x + 2
    x = x ** 2
    x = x + 3
    y = x.mean()
    return y

In [55]:
inp = jnp.arange(3, dtype=jnp.float32)
inp.__class__

jaxlib.xla_extension.ArrayImpl

In [56]:
print('Input', inp)
print('Output', simple_graph(inp))

Input [0. 1. 2.]
Output 12.666667


In [60]:
jax.make_jaxpr(simple_graph)(inp)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[3][39m = add a 2.0
    c[35m:f32[3][39m = integer_pow[y=2] b
    d[35m:f32[3][39m = add c 3.0
    e[35m:f32[][39m = reduce_sum[axes=(0,)] d
    f[35m:f32[][39m = div e 3.0
  [34m[22m[1min [39m[22m[22m(f,) }

In [61]:
global_list = []

# Invalid function with side-effect
def norm(x):
    global_list.append(x) #does not contain any operation for this one
    x = x ** 2
    n = x.sum()
    n = jnp.sqrt(n)
    return n

jax.make_jaxpr(norm)(inp)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[3][39m = integer_pow[y=2] a
    c[35m:f32[][39m = reduce_sum[axes=(0,)] b
    d[35m:f32[][39m = sqrt c
  [34m[22m[1min [39m[22m[22m(d,) }

In [62]:
grad_function = jax.grad(simple_graph)
gradients = grad_function(inp)
print('Gradient', gradients)

Gradient [1.3333334 2.        2.6666667]


In [63]:
jax.make_jaxpr(grad_function)(inp)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[3][39m = add a 2.0
    c[35m:f32[3][39m = integer_pow[y=2] b
    d[35m:f32[3][39m = integer_pow[y=1] b
    e[35m:f32[3][39m = mul 2.0 d
    f[35m:f32[3][39m = add c 3.0
    g[35m:f32[][39m = reduce_sum[axes=(0,)] f
    _[35m:f32[][39m = div g 3.0
    h[35m:f32[][39m = div 1.0 3.0
    i[35m:f32[3][39m = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] h
    j[35m:f32[3][39m = mul i e
  [34m[22m[1min [39m[22m[22m(j,) }

In [64]:
# when we just want to get the value
val_grad_function = jax.value_and_grad(simple_graph)
val_grad_function(inp)

(Array(12.666667, dtype=float32),
 Array([1.3333334, 2.       , 2.6666667], dtype=float32))

In [72]:
jitted_function = jax.jit(simple_graph)

In [73]:
rng, normal_rng = jax.random.split(rng)


In [82]:
large_input = jax.random.normal(normal_rng, (1000,))
_ = jitted_function(large_input)

In [83]:
%%timeit
simple_graph(large_input).block_until_ready()

785 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [84]:
%%timeit
jitted_function(large_input).block_until_ready()

265 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [87]:
jitted_grad_function = jax.jit(grad_function)
_ = jitted_grad_function(large_input) 

In [88]:
%%timeit
grad_function(large_input).block_until_ready()

3.54 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [89]:
%%timeit
jitted_grad_function(large_input).block_until_ready()

253 µs ± 9.36 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
