* JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API
* In practice, it can be used as NumPy that runs on accelerators (GPUs and TPUs)


In [10]:
import jax
import jax.numpy as jnp

import numpy as np

In [14]:
np_x = np.arange(10)
print(np_x)

[0 1 2 3 4 5 6 7 8 9]


In [15]:
type(np_x)

numpy.ndarray

In [7]:
jnp_x = jnp.arange(10)
print(jnp_x)

[0 1 2 3 4 5 6 7 8 9]


In [8]:
type(jnp_x)

jaxlib.xla_extension.Array
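The two array types differ, but converting between them is straightforward. Here is a minimal sketch using `jnp.asarray` and `np.asarray`. Converting a JAX array to NumPy copies the data back to host memory, so it also waits for any pending computation on that array.

```python
import numpy as np
import jax.numpy as jnp

np_x = np.arange(10)
jnp_x = jnp.arange(10)

# NumPy -> JAX: the data is placed on JAX's default device (CPU, GPU or TPU)
from_numpy = jnp.asarray(np_x)

# JAX -> NumPy: the data is copied back to host memory as a numpy.ndarray
from_jax = np.asarray(jnp_x)

print(type(from_numpy))  # a JAX Array type, not numpy.ndarray
print(type(from_jax))    # <class 'numpy.ndarray'>
```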

* The same code can run on different backends: CPU, GPU and TPU.
* When a JAX function is called, the corresponding operation is dispatched to an accelerator to be computed **asynchronously** when possible.
* This means that if we don't require the result immediately, the computation won't block Python execution.
* Unless we call **block_until_ready()** or convert the array to a regular Python type, a benchmark will time only the dispatch, not the actual computation.
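Here is a small sketch of the difference between timing the dispatch and timing the computation. It uses `time.perf_counter` rather than `%timeit` so that both measurements appear side by side. The matrix size is arbitrary, and the actual numbers (and how large the gap is) depend on the hardware and backend.

```python
import time
import jax
import jax.numpy as jnp

# Which backend JAX selected on this machine: 'cpu', 'gpu' or 'tpu'
print(jax.default_backend())
print(jax.devices())

x = jnp.ones((1000, 1000), dtype=jnp.float32)
jnp.dot(x, x).block_until_ready()  # warm-up, so one-time setup is not timed

# Without blocking: only the dispatch of the matrix product is measured
start = time.perf_counter()
y = jnp.dot(x, x)
dispatch_time = time.perf_counter() - start

# With blocking: wait until the result has actually been computed
start = time.perf_counter()
z = jnp.dot(x, x).block_until_ready()
total_time = time.perf_counter() - start

print(f"dispatch only:          {dispatch_time * 1e3:.3f} ms")
print(f"with block_until_ready: {total_time * 1e3:.3f} ms")
```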

In [16]:
long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()

7.15 ms ± 135 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
# Make sure the functions JAX operates on do not have side effects
# A side-effect is any effect of a function that doesn’t appear in its output


def in_place_modify(x):
    x[0] = 123
    return None

# This function modifies its argument but returns an unrelated value (None). The modification is a side effect.
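For comparison, here is a side-effect-free version of the same NumPy function. `pure_modify` is a hypothetical name used only for illustration. It copies its input and communicates the change only through its return value, which is the style JAX expects.

```python
import numpy as np

def pure_modify(x):
    # Work on a copy so the caller's array is left untouched (no side effect)
    y = x.copy()
    y[0] = 123
    return y  # the change is visible only through the return value

a = np.arange(10)
b = pure_modify(a)
print(a)  # [0 1 2 3 4 5 6 7 8 9]
print(b)  # [123   1   2   3   4   5   6   7   8   9]
```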

In [18]:
in_place_modify(np_x)

In [19]:
print(np_x)

[123   1   2   3   4   5   6   7   8   9]


In [20]:
# This will result in an error. 
# JAX arrays won’t allow themselves to be modified in-place
# Unlike NumPy arrays, JAX arrays are always immutable
in_place_modify(jnp_x) 

TypeError: '<class 'jaxlib.xla_extension.Array'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [26]:
def jax_in_place_modify(x):
  return x.at[0].set(123)

y = jax_in_place_modify(jnp_x)

In [27]:
# The old array is untouched: no side effects
print(y)
print(jnp_x)

[123   1   2   3   4   5   6   7   8   9]
[0 1 2 3 4 5 6 7 8 9]
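`.at[...].set` is one of several out-of-place update methods. The sketch below shows a few others (`add`, `multiply`, `min`) together with slice indexing. Each call returns a new array and leaves the original unchanged.

```python
import jax.numpy as jnp

x = jnp.arange(10)

# Each .at[...] call returns a NEW array; x itself never changes
added = x.at[0].add(100)          # index 0 becomes 0 + 100
scaled = x.at[1:4].multiply(10)   # indices 1..3 are multiplied by 10
clipped = x.at[-1].min(5)         # last element becomes min(9, 5)
zeroed = x.at[::2].set(0)         # every other element is set to 0

print(added)
print(scaled)
print(clipped)
print(zeroed)
print(x)  # still 0..9
```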


## JIT

In [30]:
# Scaled Exponential Linear Unit (SELU), an operation commonly used in deep learning
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

3.96 ms ± 79.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
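For reference, `selu` above computes the following piecewise function. Here λ is `lambda_` (default 1.05) and α is `alpha` (default 1.67). The published SELU constants are approximately α ≈ 1.6733 and λ ≈ 1.0507, so the notebook uses rounded values. The expression `alpha * jnp.exp(x) - alpha` is the second branch, α(eˣ − 1):

$$
\mathrm{selu}(x) = \lambda
\begin{cases}
x & \text{if } x > 0 \\
\alpha \left(e^{x} - 1\right) & \text{if } x \le 0
\end{cases}
$$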


* XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra, and JAX uses it under the hood
* It can accelerate TensorFlow models with potentially no source code changes

* In the code above, operations are sent to the accelerator one at a time
* This limits the ability of the XLA compiler to optimize our functions

* Optimally, we want to give the XLA compiler as much code as possible at once
* This allows XLA to optimize it fully, for example by fusing several operations into a single kernel
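One way to see how many separate operations `selu` consists of is `jax.make_jaxpr`, which traces the function and prints its sequence of primitive operations (a comparison, `exp`, multiplications, a subtraction and a select). Without `jit`, these operations are dispatched individually. Under `jit`, XLA receives them all together. The exact primitive names and count vary between JAX versions, so treat the printed jaxpr as illustrative.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)

# make_jaxpr traces selu and returns the primitive operations it performs
jaxpr = jax.make_jaxpr(selu)(x)
print(jaxpr)

# Each equation is one primitive operation
print(len(jaxpr.jaxpr.eqns), "primitive operations")
```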

In [None]:
# JAX provides the jax.jit transformation, which will JIT compile a JAX-compatible function. 

# Define selu_jit as the compiled version of selu
selu_jit = jax.jit(selu)

# Warm-up: the first call traces selu and compiles it with XLA into code optimized
# for the current backend (CPU, GPU or TPU), so compilation time is excluded from the timing below
selu_jit(x).block_until_ready()

# Time the execution speed of the compiled version
%timeit selu_jit(x).block_until_ready()
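This also shows why side effects are risky under `jit`. In the sketch below, `trace_count` is a hypothetical counter added only to observe tracing, and `selu_counted` is a hypothetical name. The Python `append` runs only while JAX traces the function, not on every call. A compiled version is cached per input shape and dtype, so a new shape triggers a new trace. The sketch also shows that `jax.jit` can be used as a decorator.

```python
import jax
import jax.numpy as jnp

trace_count = []  # hypothetical counter, used only to observe tracing

@jax.jit  # jax.jit can also be applied as a decorator
def selu_counted(x, alpha=1.67, lambda_=1.05):
    trace_count.append(1)  # Python side effect: runs only during tracing
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.linspace(-3.0, 3.0, 1_000)

selu_counted(x).block_until_ready()  # first call: trace + compile
selu_counted(x).block_until_ready()  # same shape/dtype: cached version reused
print(len(trace_count))  # 1

x2 = jnp.linspace(-3.0, 3.0, 2_000)
selu_counted(x2).block_until_ready()  # new shape: traced and compiled again
print(len(trace_count))  # 2
```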

## Automatic Vectorization