In [1]:
# Optional: run this cell only if you want to disable the GPU.
# NOTE(review): presumably this has to run before JAX is first imported for the
# setting to take effect — confirm.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [2]:
import jax
import jax.numpy as jnp
import numpy as np


In [3]:
# Create a 1-D JAX array holding the integers 0..9; print() shows only the values.
x = jnp.arange(10)
print(x)


[0 1 2 3 4 5 6 7 8 9]


In [4]:
# A bare expression displays the full repr, including the array type and dtype.
x

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [7]:
# Time a dot product of a 10-million-element vector with itself.
long_vector = jnp.arange(int(1e7))

# block_until_ready() is part of the timed expression so that each run waits for
# the result rather than returning early (JAX dispatches work asynchronously).
%timeit jnp.dot(long_vector, long_vector).block_until_ready()

208 µs ± 9.46 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [5]:
# This cell was timed with JAX_PLATFORM_NAME set to "cpu"; the code is the same
# dot-product benchmark as the cell above.

long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()


7.12 ms ± 370 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# NumPy baseline: the same dot product on a NumPy array of the same length.
# NOTE(review): jnp.arange produced int32 earlier in this notebook, whereas
# np.arange's default integer dtype is platform-dependent (often int64), so the
# timings may not be a like-for-like comparison — confirm.
long_vec_np = np.arange(int(1e7))

%timeit np.dot(long_vec_np, long_vec_np)

12.7 ms ± 308 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## JAX `grad`

JAX's `grad` takes a numerical function written in Python and returns a new Python function that computes the gradient of the original function.

In [6]:
def sum_of_squares(x):
  """Return the sum of the squared elements of ``x``."""
  squared = jnp.square(x)
  return jnp.sum(squared)

In [7]:
# jax.grad returns a new function that computes the gradient of sum_of_squares.
sum_of_squares_grad = jax.grad(sum_of_squares)

In [8]:
# Rebind x to a float array to use as the input point for the examples below.
x = jnp.array([1.0, 2.0, 3.0, 4.0])

In [9]:
# Value: 1 + 4 + 9 + 16 = 30; the gradient of sum(x**2) is 2*x.
print(sum_of_squares(x))
print(sum_of_squares_grad(x))

30.0
[2. 4. 6. 8.]


By default `jax.grad` computes the gradient with respect to the first argument of the function.

See the following example:

In [10]:
def sum_squared_error(x, y):
  """Return the sum of squared element-wise differences between ``x`` and ``y``."""
  residual = x - y
  return jnp.sum(jnp.square(residual))


# With no argnums given, the gradient is taken with respect to x (the first argument).
sum_squared_error_dx = jax.grad(sum_squared_error)

y = jnp.asarray([1.1, 2.1, 3.1, 4.1])

# Reuses the x defined in an earlier cell; the result is 2 * (x - y), about -0.2
# per element (the printed values show floating-point rounding).
print(sum_squared_error_dx(x, y))


[-0.20000005 -0.19999981 -0.19999981 -0.19999981]


To find the gradient with respect to a different argument (or several), set `argnums`:

In [11]:
# argnums=(0, 1) requests gradients for both arguments; the result is a tuple
# with one gradient array per requested argument, in the same order.
jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # gradient w.r.t. both x and y

(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))

In [12]:
def different_fun(x, y):
    """Return the sum of the squares of ``x`` plus the sum of the cubes of ``y``."""
    squares_total = jnp.sum(x ** 2)
    cubes_total = jnp.sum(y ** 3)
    return squares_total + cubes_total


In [13]:
# Gradient function that differentiates with respect to both arguments (x and y).
different_fun_grad = jax.grad(different_fun, argnums=(0, 1))

In [14]:
# Evaluate sum(x**2) + sum(y**3) at the x and y defined in earlier cells.
different_fun(x, y)


DeviceArray(139.304, dtype=float32)

In [15]:
# Returns a pair of arrays: (gradient w.r.t. x, gradient w.r.t. y), i.e. (2*x, 3*y**2).
different_fun_grad(x, y)


(DeviceArray([2., 4., 6., 8.], dtype=float32),
 DeviceArray([ 3.63    , 13.229998, 28.829998, 50.43    ], dtype=float32))

In [16]:
# A plain integer argnums selects a single argument: 0 means differentiate with
# respect to x only, so a single array (not a tuple) comes back.
different_fun_grad = jax.grad(different_fun, argnums=0)
different_fun_grad(x, y)

DeviceArray([2., 4., 6., 8.], dtype=float32)

In [17]:
# A plain integer argnums selects a single argument: 1 means differentiate with
# respect to y only, so a single array (not a tuple) comes back.
different_fun_grad = jax.grad(different_fun, argnums=1)
different_fun_grad(x, y)

DeviceArray([ 3.63    , 13.229998, 28.829998, 50.43    ], dtype=float32)

JAX has a handy function, `jax.value_and_grad`, that returns both a function's value and its gradient in a single call.

In [19]:
# Returns (value, gradients); gradients is itself a tuple here because argnums
# is a tuple of two argument positions.
jax.value_and_grad(different_fun, argnums=(0,1))(x,y)

(DeviceArray(139.304, dtype=float32),
 (DeviceArray([2., 4., 6., 8.], dtype=float32),
  DeviceArray([ 3.63    , 13.229998, 28.829998, 50.43    ], dtype=float32)))

JAX closely follows NumPy's API. One main difference is that JAX is designed to be functional. You may find JAX familiar if you are comfortable with functional programming.

In [20]:
# Unlike NumPy arrays, JAX arrays do not allow elements to be replaced in place.
# Build one array of each kind to compare.

x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = np.array([1.0, 2.0, 3.0, 4.0])

In [22]:
# NumPy: item assignment modifies the array in place.
y[1] = 10.0
y

array([ 1., 10.,  3.,  4.])

In [23]:
# Expected to fail: JAX arrays are immutable, so item assignment raises TypeError.
# The error is caught and printed so the notebook still runs top to bottom
# (Restart & Run All) while the message is shown as the demonstration.
try:
    x[1] = 10.0
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [24]:
# .at[1].set(10.0) produces an array with index 1 replaced; the result is only
# displayed here — assign it (x = x.at[1].set(10.0)) to keep it.
x.at[1].set(10.0)

DeviceArray([ 1., 10.,  3.,  4.], dtype=float32)