# JAX: getting started, a few takeaways

In [91]:
import jax
import jax.numpy as jnp
import numpy as np

There are already a great number of excellent resources online for getting started with JAX. This notebook is not intended to be a comprehensive introduction to JAX, but rather a collection of tips and tricks that I have found useful in my own work. I hope you find them useful too!

Some selected resources for getting started with JAX:
- [JAX documentation](https://jax.readthedocs.io/en/latest/index.html) (of course!). Make **sure** to read the [Common Gotchas in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) page.
- [Awesome JAX](https://github.com/n2cholas/awesome-jax): a GitHub repo with links to all sorts of other JAX-based repos.
- [get-started-with-jax](https://github.com/gordicaleksa/get-started-with-JAX): a GitHub repo with a collection of JAX tutorials, from the basics to training neural networks. These cover `jit`, `vmap`, tracer objects and more, which I won't discuss in detail here.

## numpy -> jax.numpy works well

In [92]:
# Two very basic functions: one written with NumPy, and the same thing written with jax.numpy

def simple_numpy_fn(x):
    return np.sqrt(x)

def simple_jax_fn(x):
    return jnp.sqrt(x)

x = np.array([1, 2, 3])

print(simple_numpy_fn(x))
print(simple_jax_fn(x))

print(type(simple_numpy_fn(x)))
print(type(simple_jax_fn(x)))

[1.         1.41421356 1.73205081]
[1.        1.4142135 1.7320508]
<class 'numpy.ndarray'>
<class 'jaxlib.xla_extension.ArrayImpl'>
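The two outputs above also differ in precision: NumPy computes in `float64`, whereas JAX defaults to 32-bit floats (64-bit mode has to be switched on explicitly, e.g. through the `jax_enable_x64` config flag). A minimal sketch, assuming the default JAX configuration, that checks the dtypes and converts a JAX array back to a NumPy array with `np.asarray`:

```python
import numpy as np
import jax.numpy as jnp

x = np.array([1, 2, 3])

np_result = np.sqrt(x)    # NumPy computes in float64
jax_result = jnp.sqrt(x)  # JAX defaults to float32 (unless jax_enable_x64 is set)
print(np_result.dtype, jax_result.dtype)

# np.asarray turns a JAX array back into a plain numpy.ndarray (dtype is kept)
back_to_numpy = np.asarray(jax_result)
print(type(back_to_numpy), back_to_numpy.dtype)
```

If your code relies on double precision, keep this default in mind when porting NumPy code.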


## ... but not always

Some examples where JAX behaves differently from NumPy:
1. Item assignments have to be done differently:

In [93]:
# Numpy

print("Numpy")
x = np.array([1, 2, 3])
print(x)
x[0] = 5
print(x)

# Jax
print("jax")
x = jnp.array([1, 2, 3])
print(x)
try:
    x[0] = 5
except Exception as e:
    print(e)
y = x.at[0].set(5)  # functional update: returns a new array
print(x)  # x itself is unchanged, JAX arrays are immutable

Numpy
[1 2 3]
[5 2 3]
jax
[1 2 3]
'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
[1 2 3]
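The `.at[...]` property supports more than `set`. A short sketch of a few of the update methods documented for `jax.numpy.ndarray.at` (the page linked in the error message above); each call returns a new array and leaves `x` untouched:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

y = x.at[0].set(5)       # replace an element      -> [5, 2, 3]
z = x.at[1:].add(10)     # add to a slice          -> [1, 12, 13]
w = x.at[2].multiply(2)  # scale an element        -> [1, 2, 6]
m = x.at[0].max(7)       # elementwise max with 7  -> [7, 2, 3]

print(x)  # still [1 2 3]: the original array is never modified
print(y, z, w, m)
```

Inside a `jax.jit`-compiled function, XLA can often turn these updates into in-place operations, so the functional style does not necessarily cost an extra copy.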


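2. Functions have to be pure: `jax.jit` traces a function on its first call and reuses the compiled version for later inputs of the same shape and dtype, so later changes to global variables are silently ignored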

In [94]:
global_variable = 0.0
def fn(x):
    return x + global_variable
fn_jit = jax.jit(fn)
print("Normal function")
print(fn(0))
print(fn(1))
print(fn(2))

print("jitted function")
print(fn_jit(0))  # traced and compiled on this first call, reading global_variable = 0.0
print(fn_jit(1))
print(fn_jit(2))

global_variable = 100.0
print(f"Now, global variable = {global_variable}")
print("Normal function")
print(fn(0))
print(fn(1))
print(fn(2))

print("jitted function")
print(fn_jit(0))  # not retraced: the cached compiled version still has global_variable = 0.0 baked in
print(fn_jit(1))
print(fn_jit(2))

Normal function
0.0
1.0
2.0
jitted function
0.0
1.0
2.0
Now, global variable = 100.0
Normal function
100.0
101.0
102.0
jitted function
0.0
1.0
2.0
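The pure alternative is to pass everything the function depends on as an argument, so it is traced as an input instead of being baked in as a constant. The sketch below (the names `fn_pure` and `trace_log` are hypothetical) does that, and also uses a deliberate Python side effect, appending to a list, to show that the function body only runs while tracing: the second call has the same input shapes and dtypes, so it reuses the cached compilation.

```python
import jax

trace_log = []  # hypothetical helper: records every time the body of fn_pure is traced

def fn_pure(x, offset):
    trace_log.append("traced")  # Python side effect: only runs during tracing
    return x + offset           # offset is an explicit argument, not a global

fn_pure_jit = jax.jit(fn_pure)

print(fn_pure_jit(1.0, 0.0))    # 1.0
print(fn_pure_jit(1.0, 100.0))  # 101.0: the new offset is picked up
print(len(trace_log))           # 1: the second call reused the cached trace
```

For debugging inside jitted code, `jax.debug.print` runs on every call, unlike Python's `print`, which (like the `append` above) only fires during tracing.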


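3. Random numbers are a bit annoying: JAX has no hidden global random state, so you have to create, pass around and split PRNG keys yourself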

In [95]:
# numpy
print("Numpy")

print(np.random.random())
print(np.random.random())
print(np.random.random())

# jax
key = jax.random.PRNGKey(42)  # explicit PRNG state: we have to keep track of the key ourselves

print("This doesn't work")  # reusing the same key always gives the same sample
print(jax.random.normal(key, shape=(1,)))
print(jax.random.normal(key, shape=(1,)))
print(jax.random.normal(key, shape=(1,)))

print("This does!")  # split off a fresh subkey for every draw
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, shape=(1,)))
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, shape=(1,)))
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, shape=(1,)))

print("Note: the key is seeded with 42, so rerunning this cell reproduces the same JAX numbers (unlike the NumPy ones)")

Numpy
0.5362616646686718
0.4502403097327451
0.9086285597347172
This doesn't work
[-0.18471177]
[-0.18471177]
[-0.18471177]
This does!
[1.3694694]
[-0.19947024]
[-2.2982783]
Note: the key is seeded with 42, so rerunning this cell reproduces the same JAX numbers (unlike the NumPy ones)
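Calling `split` before every single draw gets verbose. Two common alternatives are to split a key into several subkeys at once, or to draw all samples in one call via the `shape` argument. A minimal sketch (variable names are illustrative):

```python
import jax

key = jax.random.PRNGKey(42)

# Reusing a key is deterministic: the same key gives identical samples
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# Split once into a new carry key plus three independent subkeys
key, *subkeys = jax.random.split(key, 4)
samples = [jax.random.normal(k) for k in subkeys]  # one scalar per subkey

# Or draw all three values in a single call from one fresh subkey
key, subkey = jax.random.split(key)
batch = jax.random.normal(subkey, shape=(3,))

print(a, b)
print(samples)
print(batch)
```

Keeping `key` as the "carry" and only ever consuming subkeys is what guarantees that no key is accidentally reused.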
