# $\texttt{jax}$ - getting started, a few take-aways

```python
import jax
import jax.numpy as jnp
import numpy as np
```

There are already a great number of excellent resources online for getting started with JAX. This notebook is not intended to be a comprehensive introduction to JAX, but rather a collection of tips and tricks that I have found useful in my own work. I hope you find them useful too!

Some selected resources for getting started with JAX:
- [JAX documentation](https://jax.readthedocs.io/en/latest/index.html) (of course!); make **sure** to read the [common gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) page.
- [Awesome JAX](https://github.com/n2cholas/awesome-jax): a GitHub repo with links to all sorts of other JAX-based repos.
- [get-started-with-jax](https://github.com/gordicaleksa/get-started-with-JAX): a GitHub repo with a collection of JAX tutorials, from the basics to training neural networks. These cover `jit`, `vmap`, tracer objects and more, which I won't discuss in detail here.

## numpy -> jax.numpy works well

```python
def simple_numpy_fn(x):
    return np.sqrt(x)

def simple_jax_fn(x):
    return jnp.sqrt(x)

x = np.array([1, 2, 3])

print(simple_numpy_fn(x))
print(simple_jax_fn(x))

print(type(simple_numpy_fn(x)))
print(type(simple_jax_fn(x)))
```

```text
[1.         1.41421356 1.73205081]
[1.        1.4142135 1.7320508]
<class 'numpy.ndarray'>
<class 'jaxlib.xla_extension.ArrayImpl'>
```
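Swapping `np` for `jnp` is what makes a function usable with JAX transformations such as `jax.grad`. A minimal sketch of the difference (the two functions are redefined so the snippet runs on its own; the exact exception raised for the NumPy version may vary between JAX versions, so the sketch only reports its name):

```python
import jax
import jax.numpy as jnp
import numpy as np

def simple_numpy_fn(x):
    return np.sqrt(x)

def simple_jax_fn(x):
    return jnp.sqrt(x)

# jax.numpy functions can be differentiated: d/dx sqrt(x) = 1 / (2 sqrt(x))
grad_at_4 = jax.grad(simple_jax_fn)(4.0)
print(grad_at_4)

# NumPy functions cannot: np.sqrt tries to turn the JAX tracer into a NumPy array
try:
    jax.grad(simple_numpy_fn)(4.0)
    numpy_grad_failed = False
except Exception as e:
    numpy_grad_failed = True
    print(type(e).__name__)

# JAX arrays convert back to NumPy arrays with np.asarray
back_to_numpy = np.asarray(simple_jax_fn(np.array([1.0, 4.0, 9.0])))
print(type(back_to_numpy))
```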


## ... but not always

Some examples where JAX behaves differently from NumPy:
1. Item assignments have to be done differently:

```python
# NumPy: arrays are mutable
print("Numpy")
x = np.array([1, 2, 3])
print(x)
x[0] = 5
print(x)

# JAX: arrays are immutable
print("jax")
x = jnp.array([1, 2, 3])
print(x)
try:
    x[0] = 5
except Exception as e:
    print(e)
y = x.at[0].set(5)  # returns an updated copy
print(x)  # x itself is unchanged
print(y)
```

```text
Numpy
[1 2 3]
[5 2 3]
jax
[1 2 3]
'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
[1 2 3]
[5 2 3]
```
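`.set` is only one of the `.at[...]` update methods; the page linked in the error message lists others such as `.add`, `.multiply`, `.min`, `.max` and `.get`. A short sketch of a few of them, each returning a new array and leaving `x` untouched (inside `jax.jit`, the compiler can often perform such updates in place, so the copy is usually not a performance concern):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

y_set = x.at[0].set(5)                       # replace x[0]
y_add = x.at[1].add(10)                      # x[1] += 10, out of place
y_mul = x.at[jnp.array([0, 2])].multiply(2)  # index arrays work too
y_slice = x.at[1:].set(0)                    # and so do slices

print(x)  # unchanged
print(y_set, y_add, y_mul, y_slice)
```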


2. Functions have to be pure

```python
global_variable = 0.0

def fn(x):
    return x + global_variable

fn_jit = jax.jit(fn)

print("Normal function")
print(fn(0))
print(fn(1))
print(fn(2))

print("jitted function")
print(fn_jit(0))
print(fn_jit(1))
print(fn_jit(2))

global_variable = 100.0
print(f"Now, global variable = {global_variable}")
print("Normal function")
print(fn(0))
print(fn(1))
print(fn(2))

print("jitted function")
# Same input shape/dtype, so the cached trace is reused: global_variable = 0.0 is still baked in
print(fn_jit(0))
print(fn_jit(1))
print(fn_jit(2))
```

```text
Normal function
0.0
1.0
2.0
jitted function
0.0
1.0
2.0
Now, global variable = 100.0
Normal function
100.0
101.0
102.0
jitted function
0.0
1.0
2.0
```
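The usual fix is to pass everything the function depends on as an argument, so that JAX treats it as an input rather than as a constant captured at trace time. The sketch below (with hypothetical names `fn_pure`, `double` and `trace_log`) also shows why side effects are a problem: Python code inside a jitted function only runs while JAX traces it, not on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fn_pure(x, offset):
    # offset is a traced input, so new values are picked up without retracing
    return x + offset

print(fn_pure(0., 0.))
print(fn_pure(0., 100.))

trace_log = []

@jax.jit
def double(x):
    trace_log.append("traced")  # Python side effect: only runs during tracing
    return 2 * x

double(1.)
double(2.)           # same shape/dtype: cached version reused, no new trace
double(jnp.ones(3))  # new shape: traced again
print(len(trace_log))
```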


3. Random numbers are a bit annoying

```python
# NumPy: a hidden global random state
print("Numpy")

print(np.random.random())
print(np.random.random())
print(np.random.random())

# JAX: explicit random state, we have to keep track of the key ourselves
key = jax.random.PRNGKey(42)

print("Reusing the same key gives the same number:")
print(jax.random.normal(key, shape=(1,)))
print(jax.random.normal(key, shape=(1,)))
print(jax.random.normal(key, shape=(1,)))

print("Splitting the key gives new numbers:")
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, shape=(1,)))
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, shape=(1,)))
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, shape=(1,)))

print("Note: the seed is fixed, so rerunning this cell reproduces the same JAX numbers (unlike the NumPy ones)")
```

```text
Numpy
0.13770181082435917
0.18422509633936113
0.13787038113921923
Reusing the same key gives the same number:
[-0.18471177]
[-0.18471177]
[-0.18471177]
Splitting the key gives new numbers:
[1.3694694]
[-0.19947024]
[-2.2982783]
Note: the seed is fixed, so rerunning this cell reproduces the same JAX numbers (unlike the NumPy ones)
```
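A common pattern is to split off several subkeys at once with the `num` argument of `jax.random.split`, and to have functions take a key as an explicit input. A minimal sketch, where `sample_noise` is a hypothetical helper:

```python
import jax
import jax.numpy as jnp

def sample_noise(key, n):
    # Hypothetical helper: consumes a key and returns n standard normal draws
    return jax.random.normal(key, shape=(n,))

key = jax.random.PRNGKey(42)
key, *subkeys = jax.random.split(key, num=4)  # 1 key to carry on + 3 subkeys

draws = [sample_noise(k, 2) for k in subkeys]
for d in draws:
    print(d)

# Same key in, same numbers out: reproducible by construction
same = sample_noise(subkeys[0], 2)
print(jnp.array_equal(draws[0], same))
```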


4. Control flow is annoying as well

```python
def my_clip(number):
    if number < 0:
        return 0
    else:
        return 1
```

```python
print(my_clip(-1))
print(my_clip(5))

# make it faster with jit!!! ... or not :(

my_clip_jit = jax.jit(my_clip)

try:
    print(my_clip_jit(-1))
    print(my_clip_jit(5))
except Exception as e:
    print(e)
```

```text
0
1
Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function my_clip at /var/folders/xj/6sk96vdn15385fpjn9nd1rcw0000gp/T/ipykernel_25333/519351608.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument number.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```


Solution with `jax.lax.cond`:

```python
@jax.jit
def my_clip_correct_jit(number):
    def less_than_zero(number):
        return jnp.array(0.)

    def greater_or_equal_zero(number):
        return jnp.array(1.)

    # Both branches are traced; the traced predicate selects which one runs
    return jax.lax.cond(number < 0, less_than_zero, greater_or_equal_zero, number)

print(my_clip_correct_jit(-1))
print(my_clip_correct_jit(5))
```

```text
0.0
1.0
```
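When the condition depends on something that is fixed for a given compilation (a mode flag, say) rather than on traced data, another option is to keep the plain Python `if` and mark that argument as static. A sketch with a hypothetical `scale` function; note that JAX compiles a separate version for every distinct value of a static argument, so this only fits arguments with few possible values:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def scale(x, use_double):
    # use_double is a plain Python bool at trace time, so `if` is allowed here
    if use_double:
        return 2 * x
    return x

print(scale(jnp.array(3.), True))
print(scale(jnp.array(3.), False))
```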


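Smarter solution: express the same step with `jnp.where`, an element-wise array operation that traces without problems. Note that `jnp.clip(number, 0, 1)` is *not* a drop-in replacement here: `my_clip` only ever returns 0 or 1, whereas clipping leaves values in between unchanged (0.5 stays 0.5) and maps 0 to 0 instead of 1.

```python
@jax.jit
def my_clip_smart_jit(number):
    # 0. where number < 0, else 1. (same behavior as my_clip)
    return jnp.where(number < 0, 0., 1.)
```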
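Unlike `jax.lax.cond`, whose predicate must be a scalar, the `jnp.where` version also works element-wise on whole arrays. A quick check that it agrees with the plain Python `my_clip` (both redefined here so the snippet is self-contained; `step_where` is a hypothetical name):

```python
import jax
import jax.numpy as jnp

def my_clip(number):
    if number < 0:
        return 0
    else:
        return 1

@jax.jit
def step_where(number):
    return jnp.where(number < 0, 0., 1.)

inputs = jnp.array([-2., -0.5, 0., 0.5, 3.])
vectorized = step_where(inputs)  # one call on the whole array
reference = jnp.array([float(my_clip(float(v))) for v in inputs])

print(vectorized)
print(jnp.array_equal(vectorized, reference))
```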

# $\texttt{jax}$ sharp bits

Make sure to read JAX's [sharp bits page](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) and the [FAQ page](https://jax.readthedocs.io/en/latest/faq.html) to learn how to write JAX code and avoid common pitfalls.

## NaNs in gradients with `jnp.where`

There is an obscure bug that has been bugging me lately (pun not intended). The problem can be described as follows.

I have some code with gradients that works fine. But once the code is jitted, I get NaNs. They easily go unnoticed, but they are flagged when the JAX config is set with `jax.config.update('jax_debug_nans', True)`. The NaNs are not flagged when the code is not jitted.

After some debugging and Googling, the problem seems to come from the use of `jnp.where`. See the FAQ page for details.

Let's see a simple example that reproduces the problem.

```python
def my_log(x: float):
    """
    Compute the logarithm but make sure the entry is positive, for which we use where.

    Args:
        x (float): A number, hopefully positive

    Returns:
        float: The log of the number, hopefully
    """
    return jnp.where(x > 0., jnp.log(x), 0.)
```

```python
print(my_log(0.))  # OK
print(jax.grad(my_log)(0.))  # NaN
```

```text
0.0
nan
```
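To see the flagging described above, `jax_debug_nans` can be switched on temporarily. A minimal sketch: the gradient is jitted to match the setting where the NaNs showed up, JAX reports a detected NaN by raising `FloatingPointError` (caught generically here), and the flag is reset afterwards so later cells are unaffected:

```python
import jax
import jax.numpy as jnp

def my_log(x):
    return jnp.where(x > 0., jnp.log(x), 0.)

jax.config.update("jax_debug_nans", True)
try:
    jax.jit(jax.grad(my_log))(0.)
    nan_flagged = False
except Exception as e:  # FloatingPointError when a NaN is detected
    nan_flagged = True
    print(type(e).__name__)
finally:
    # Switch the check off again so the rest of the notebook runs normally
    jax.config.update("jax_debug_nans", False)

print(nan_flagged)
```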


The FAQ entry on this issue ([link](https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where)) gives a short explanation:
>A short explanation is that during grad computation the adjoint corresponding to the undefined jnp.log(x) is a NaN and it gets accumulated to the adjoint of the jnp.where. The correct way to write such functions is to ensure that there is a jnp.where inside the partially-defined function, to ensure that the adjoint is always finite:

```python
def safe_for_grad_log(x):
    return jnp.log(jnp.where(x > 0., x, 1.))
```

```python
print(safe_for_grad_log(0.))  # OK!
print(jax.grad(safe_for_grad_log)(0.))  # OK!
```

```text
0.0
0.0
```
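The inner `jnp.where` only changes what happens where `x <= 0`; for positive inputs the gradient is still the usual 1/x. A quick check over several points at once with `jax.vmap` (the function is redefined so the snippet runs on its own):

```python
import jax
import jax.numpy as jnp

def safe_for_grad_log(x):
    return jnp.log(jnp.where(x > 0., x, 1.))

xs = jnp.array([0., 1., 2., 4.])
grads = jax.vmap(jax.grad(safe_for_grad_log))(xs)
print(grads)  # 1/x for x > 0, and 0 at x = 0
```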


Another example, returning `y` instead of 0 when `x <= 0`:

```python
def my_log_or_y(x, y):
    """Return log(x) if x > 0 or y"""
    return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
```

```python
print(my_log_or_y(0., 5.))  # OK!
print(jax.grad(my_log_or_y)(0., 5.))  # OK!
```

```text
5.0
0.0
```
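The same double-`where` pattern applies to other functions whose derivative blows up at a single point. For example, the derivative of `jnp.sqrt` is infinite at 0, so a Euclidean norm written as `jnp.sqrt(jnp.sum(v ** 2))` has a NaN gradient at the zero vector. A sketch with hypothetical helpers `naive_norm` and `safe_norm`:

```python
import jax
import jax.numpy as jnp

def naive_norm(v):
    return jnp.sqrt(jnp.sum(v ** 2))

def safe_norm(v):
    sq = jnp.sum(v ** 2)
    is_zero = sq == 0.
    # Inner where: feed sqrt a harmless value where its derivative would be infinite
    safe_sq = jnp.where(is_zero, 1., sq)
    # Outer where: return the correct value (0) at the zero vector
    return jnp.where(is_zero, 0., jnp.sqrt(safe_sq))

zero = jnp.zeros(3)
print(jax.grad(naive_norm)(zero))
print(jax.grad(safe_norm)(zero))
print(jax.grad(safe_norm)(jnp.array([3., 4.])))
```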
