# $\texttt{jax}$: my sharp bits

**Abstract**: Many sharp bits are already covered in the official documentation (look [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)), but I found some of them hard to find. So I decided to write this note to record the ones I personally came across and had trouble debugging.

```python
import jax
import jax.numpy as jnp
import numpy as np
```

# NaNs in gradients and `jnp.where`

There is an obscure bug that seems to be bugging me right now (pun not intended). The problem can be described as follows.

I have some code involving gradients that works fine. But once the code is jitted, things break and I get NaNs. The NaNs are hard to track down, but they are flagged when NaN checking is enabled with `jax.config.update('jax_debug_nans', True)`. They are not flagged when the code is not jitted.
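
A minimal sketch of turning the flag on around a jitted gradient and catching the error it raises. The helper `grad_raises_with_debug_nans` and the function `safe_log` are hypothetical names introduced here; the sketch assumes that, with `jax_debug_nans` enabled, JAX raises a `FloatingPointError` when a NaN is produced, and it switches the (global) flag back off afterwards, assuming it was off to begin with.

```python
import jax
import jax.numpy as jnp


def my_log(x):
    # Same as the example below: mask the *output* of log with where.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def safe_log(x):
    # Hypothetical "where inside log" variant: log never sees x <= 0.
    return jnp.log(jnp.where(x > 0.0, x, 1.0))


def grad_raises_with_debug_nans(f, x):
    """Hypothetical helper: True if jit(grad(f))(x) trips jax_debug_nans."""
    jax.config.update("jax_debug_nans", True)  # global switch
    try:
        jax.jit(jax.grad(f))(x).block_until_ready()
        return False
    except FloatingPointError:
        return True
    finally:
        # Assumption: the flag was off before, so restore it to False.
        jax.config.update("jax_debug_nans", False)


print(grad_raises_with_debug_nans(my_log, 0.0))    # naive version
print(grad_raises_with_debug_nans(safe_log, 0.0))  # where-inside-log version
```

Because `jax_debug_nans` is process-wide, wrapping it in `try`/`finally` keeps one debugging session from silently slowing down everything that runs afterwards.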

After some debugging and Googling, the problem seems to come from the use of `jnp.where`. Let's see a simple example that reproduces the problem.

```python
def my_log(x: float):
    """
    Compute the logarithm, using `jnp.where` to return 0 for non-positive inputs.

    Args:
        x (float): A number, hopefully positive

    Returns:
        float: The log of the number if it is positive, 0 otherwise
    """
    return jnp.where(x > 0., jnp.log(x), 0.)
```

```python
my_log(0.)  # works OK
```

```
Array(0., dtype=float32, weak_type=True)
```

```python
jax.grad(my_log)(0.)  # gives NaN
```

```
Array(nan, dtype=float32, weak_type=True)
```

From the FAQ documentation page that discusses this ([link](https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where)):
>A short explanation is that during grad computation the adjoint corresponding to the undefined jnp.log(x) is a NaN and it gets accumulated to the adjoint of the jnp.where. The correct way to write such functions is to ensure that there is a jnp.where inside the partially-defined function, to ensure that the adjoint is always finite:
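
To make that explanation concrete, the chain rule can be unpacked by hand with `jax.vjp` (the lambda and its placeholder branch values are only for illustration). The VJP of `jnp.where` sends a cotangent of exactly 0 to the branch that was not selected, and the VJP of `jnp.log` then combines that 0 with the derivative 1/x, which is infinite at x = 0. Under IEEE floating point, 0 times infinity is NaN, and `jnp.where` cannot mask it because it is produced upstream of the `where` in the backward pass.

```python
import jax
import jax.numpy as jnp

x = 0.0

# Step 1: VJP of where (condition False) w.r.t. both branches.
# The branch values 1.0 and 2.0 are arbitrary placeholders.
_, where_vjp = jax.vjp(lambda a, b: jnp.where(x > 0.0, a, b), 1.0, 2.0)
ct_log_branch, ct_other_branch = where_vjp(1.0)
print(ct_log_branch, ct_other_branch)  # unselected (log) branch gets cotangent 0

# Step 2: push that zero cotangent back through log at x = 0.
_, log_vjp = jax.vjp(jnp.log, x)
(ct_x,) = log_vjp(ct_log_branch)
print(ct_x)  # nan: 0 combined with the infinite derivative 1/x

# The same effect in plain IEEE arithmetic.
print(0.0 * jnp.inf)  # nan
```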

```python
def safe_for_grad_log(x):
    return jnp.log(jnp.where(x > 0., x, 1.))
```

```python
print(safe_for_grad_log(0.))  # OK!
print(jax.grad(safe_for_grad_log)(0.))  # OK!
```

```
0.0
0.0
```
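
A quick check over several inputs, a sketch using `jax.vmap` that is not part of the original notebook, comparing the two versions with and without `jax.jit`. The naive version's gradient is NaN only at exactly x = 0: for negative x the zero cotangent is combined with the finite derivative 1/x, so the result stays finite (it is -0.0).

```python
import jax
import jax.numpy as jnp


def my_log(x):
    # Naive: mask only the output of log.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def safe_for_grad_log(x):
    # Safe: mask the input, so log never sees a non-positive value.
    return jnp.log(jnp.where(x > 0.0, x, 1.0))


xs = jnp.array([-1.0, 0.0, 2.0])

# Per-element gradients: vmap over the scalar-input grad.
naive_grads = jax.vmap(jax.grad(my_log))(xs)
safe_grads = jax.vmap(jax.grad(safe_for_grad_log))(xs)
safe_grads_jit = jax.jit(jax.vmap(jax.grad(safe_for_grad_log)))(xs)

print(naive_grads)     # NaN at x = 0 only
print(safe_grads)      # values 0, 0, 0.5
print(safe_grads_jit)  # same values under jit
```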


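The documentation page also mentions that you may need an additional outer `jnp.where` when the function should return a different value (here `y`) outside the domain of `log` (note that the documentation page has a syntax error in that cell):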

```python
def my_log_or_y(x, y):
    """Return log(x) if x > 0 or y"""
    return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
```

```python
print(my_log_or_y(0., 5.))  # OK!
print(jax.grad(my_log_or_y)(0., 5.))  # OK!
```

```
5.0
0.0
```
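
The cell above only differentiates with respect to `x` (the default `argnums=0`). A small sketch using `argnums=(0, 1)` to check the fallback argument `y` as well: outside the domain the whole gradient flows to `y`, inside it the whole gradient flows to `x`. The name `grad_both` is just a local variable for illustration.

```python
import jax
import jax.numpy as jnp


def my_log_or_y(x, y):
    """Return log(x) if x > 0 or y."""
    # Inner where keeps log's input positive; outer where picks the fallback y.
    return jnp.where(x > 0.0, jnp.log(jnp.where(x > 0.0, x, 1.0)), y)


grad_both = jax.grad(my_log_or_y, argnums=(0, 1))

dx0, dy0 = grad_both(0.0, 5.0)  # outside the domain: output is y
print(dx0, dy0)  # 0.0 1.0

dx2, dy2 = grad_both(2.0, 5.0)  # inside the domain: output is log(x)
print(dx2, dy2)  # 0.5 0.0
```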
