# JAX basics

This is a personal notebook for learning [JAX](https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html).

In [1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

In [18]:
from jax import make_jaxpr, jit, grad, jacobian, jacfwd, jacrev, hessian, random

from functools import partial

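JAX arrays are immutable, so items cannot be assigned in place. Use `x.at[idx].set(value)` instead, which returns an updated copy and leaves the original array unchanged.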

In [3]:
# In numpy
x = np.arange(10)
x[0] = 10
x

array([10,  1,  2,  3,  4,  5,  6,  7,  8,  9])

In [4]:
# In Jax
y = jnp.arange(10)
y.at[0].set(10)

Array([10,  1,  2,  3,  4,  5,  6,  7,  8,  9], dtype=int32)
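As a small sketch of the point above: the `.at[...]` update returns a new array, the original stays as it was, and direct item assignment raises a `TypeError`. The variable names (`y_new`, `y_added`, `y_scaled`) are only illustrative.

```python
import jax.numpy as jnp

y = jnp.arange(10)

# .at[...].set(...) returns a new array; `y` itself is left untouched.
y_new = y.at[0].set(10)

# Other out-of-place updates follow the same pattern.
y_added = y.at[1:3].add(5)       # add 5 to elements 1 and 2
y_scaled = y.at[-1].multiply(2)  # double the last element

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    y[0] = 10
    assignment_failed = False
except TypeError:
    assignment_failed = True

print(y[0], y_new[0])
print(y_added[:4])
print(y_scaled[-1], assignment_failed)
```

Because nothing is modified in place, the result has to be bound to a name (for example `y = y.at[0].set(10)`) if the update should persist.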

In [5]:
y.devices(), y.sharding

({CudaDevice(id=0)},
 SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device))

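JAX performance with and without `jit`. Because `X` is a `jnp` array, the plain `norm(X)` call below (stored in `res_np`) also runs through JAX, op by op, rather than through NumPy.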

In [6]:
def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

acc_norm = jit(norm)
np.random.seed(42)

X = jnp.array(np.random.rand(10000,10))

In [7]:
%%time
res_np = norm(X)

CPU times: user 111 ms, sys: 8.85 ms, total: 120 ms
Wall time: 597 ms


In [8]:
%%time 
res_jax = acc_norm(X)

CPU times: user 61.4 ms, sys: 1.36 ms, total: 62.7 ms
Wall time: 94.4 ms


In [9]:
np.allclose(res_np,res_jax,atol=1E-6)

True

In [10]:
%timeit norm(X).block_until_ready()
%timeit acc_norm(X).block_until_ready()

93.2 μs ± 1 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
34.9 μs ± 696 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
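To separate compilation cost from steady-state run time, and to include a genuine NumPy baseline, something like the following sketch could be used. The helper `timed` and the names `X_np`, `X_jax` and `norm_jit` are assumptions made for this example. No timings are quoted because they depend on the hardware.

```python
import time

import jax.numpy as jnp
import numpy as np
from jax import jit


def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)


norm_jit = jit(norm)

X_np = np.random.default_rng(42).random((10000, 10)).astype(np.float32)
X_jax = jnp.asarray(X_np)


def timed(fn, *args):
    """Run fn once and return (result, elapsed seconds)."""
    start = time.perf_counter()
    out = fn(*args)
    # JAX dispatches work asynchronously, so wait for the result
    # before stopping the clock. NumPy arrays have no such method.
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return out, time.perf_counter() - start


res_np, t_np = timed(norm, X_np)              # pure NumPy
res_first, t_first = timed(norm_jit, X_jax)   # includes tracing + compilation
res_second, t_second = timed(norm_jit, X_jax) # reuses the compiled function

print(f"numpy: {t_np:.6f}s, jit first: {t_first:.6f}s, jit second: {t_second:.6f}s")
```

The first jitted call is usually the slowest of the three, which is why `%timeit` (many repetitions) gives a fairer picture than a single `%%time` cell.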


The first call to a jitted function triggers tracing and compilation. Only JAX operations are compiled, so Python side effects such as `print` run during tracing and are not shown on later calls with the same input shapes and dtypes. `make_jaxpr` shows the traced program (the jaxpr) behind the function.

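Note that under `jit` you cannot branch in Python on the value of a traced argument, and array shapes must be known at trace time. Marking an argument as static with `static_argnums` (here through `functools.partial`) makes such branching possible, but every new value of the static argument results in a recompilation.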

In [11]:
@jit
def f(x, y):
  print("Running f():")
  print(f"  {x = }")
  print(f"  {y = }")
  result = jnp.dot(x + 1, y + 1)
  print(f"  {result = }")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<float32[3,4]>with<DynamicJaxprTrace>
  y = Traced<float32[4]>with<DynamicJaxprTrace>
  result = Traced<float32[3]>with<DynamicJaxprTrace>


Array([ 0.02217698,  4.49119   , -0.18381482], dtype=float32)

In [12]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)

Array([4.6115603, 6.641762 , 6.2128515], dtype=float32)
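The missing `print` output on the second call can be made explicit by recording when tracing happens. This is a minimal sketch: `add_one` and `trace_log` are hypothetical names, and appending to the list is an ordinary Python side effect that only runs while JAX traces the function.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # one entry is appended each time the function is traced


@jit
def add_one(x):
    trace_log.append(x.shape)  # Python side effect: runs at trace time only
    return x + 1


add_one(jnp.ones(3))   # first call: traces and compiles
add_one(jnp.zeros(3))  # same shape and dtype: compiled version is reused
count_same_shape = len(trace_log)

add_one(jnp.ones(4))   # new shape: the function is traced again
count_new_shape = len(trace_log)

print(count_same_shape, count_new_shape)
```

The count only grows when the input shape (or dtype) changes, which matches what the `print` calls in `f` show.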

In [13]:
make_jaxpr(f)(x,y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3] = pjit[
      name=f
      jaxpr={ lambda ; a:f32[3,4] b:f32[4]. let
          d:f32[3,4] = add a 1.0:f32[]
          e:f32[4] = add b 1.0:f32[]
          c:f32[3] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] d e
        in (c,) }
    ] a b
  in (c,) }

In [14]:
@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)

Array(-1, dtype=int32, weak_type=True)
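When the flag should stay a regular traced argument, one common alternative to `static_argnums` is to select between both results with `jnp.where` instead of a Python `if`. The sketch below compares the two approaches. The function names are made up for this example.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit


@partial(jit, static_argnums=(1,))
def maybe_negate_static(x, neg):
    # `neg` is a plain Python bool here, so an ordinary `if` works.
    # Each distinct value of `neg` gets its own compiled version.
    return -x if neg else x


@jit
def maybe_negate_traced(x, neg):
    # `neg` is traced, so both branches are expressed in the computation
    # and jnp.where picks the result element-wise.
    return jnp.where(neg, -x, x)


x = jnp.arange(3.0)
out_static = maybe_negate_static(x, True)
out_traced_neg = maybe_negate_traced(x, True)
out_traced_pos = maybe_negate_traced(x, False)

print(out_static)
print(out_traced_neg)
print(out_traced_pos)
```

The `jnp.where` version compiles once for both flag values, at the price of computing both branches.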

# JAX automatic differentiation

In [15]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]
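A quick way to sanity-check a gradient from `grad` is to compare it with central finite differences. The helper `finite_difference_grad` and the step size `eps=1e-3` are assumptions chosen for this sketch, not part of JAX.

```python
import jax.numpy as jnp
from jax import grad


def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


def finite_difference_grad(f, x, eps=1e-3):
    # Central difference along each coordinate direction.
    basis = jnp.eye(len(x))
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in basis])


x_small = jnp.arange(3.0)
autodiff = grad(sum_logistic)(x_small)
numeric = finite_difference_grad(sum_logistic, x_small)

print(autodiff)
print(numeric)
```

The two results should agree only approximately, since finite differences in float32 are limited by rounding error.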


You can mix and match `grad` with `jit`. `grad` works similarly to [autograd](https://github.com/HIPS/autograd).

In [16]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.0353256


JAX offers several ways to compute Jacobians: `jacfwd` (forward mode), `jacrev` (reverse mode), and `jacobian` (an alias of `jacrev`). There is also `hessian` for second derivatives.

In [19]:
jacobian(jnp.exp)(x_small)

Array([[1.       , 0.       , 0.       ],
       [0.       , 2.7182817, 0.       ],
       [0.       , 0.       , 7.389056 ]], dtype=float32)

In [21]:
def hessian_m(fun):
    return jit(jacfwd(jacrev(fun)))

np.allclose(hessian_m(sum_logistic)(x_small),hessian(sum_logistic)(x_small), atol=1E-6)

True
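To see that `jacfwd` and `jacrev` return the same matrix for a non-square Jacobian, a small sketch with a made-up function `g` from R^3 to R^2 might look like this. The function and its expected Jacobian are illustrative only.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev


def g(x):
    # Hypothetical map from R^3 to R^2, used only for illustration.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])


x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jacfwd(g)(x)  # built column by column (one pass per input)
J_rev = jacrev(g)(x)  # built row by row (one pass per output)

# Analytic Jacobian for comparison: rows are the gradients of each output.
J_expected = jnp.array([[x[1], x[0], 0.0],
                        [0.0, 0.0, jnp.cos(x[2])]])

print(J_fwd.shape)
print(J_fwd)
```

As a rule of thumb, forward mode tends to suit functions with few inputs and many outputs, and reverse mode the opposite case.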

Just like 

In [None]:
key = random.key(1701)
key1, key2 = random.split(key)
mat = random.normal(key1, (150,100))
batched_x = random.normal(key2, (10,100))

def apply_matrix(x):
    return jnp.dot(mat,x)

In [None]:
def v1_batch_apply(v_batched):
    # Naive batching: apply the matrix to each row in a Python loop, then stack.
    return jnp.stack([apply_matrix(v) for v in v_batched])

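def v2_batch_apply(v_batched):
    # Assumed completion (the original body was left empty):
    # batch manually with a single matrix product.
    return jnp.dot(v_batched, mat.T)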
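One way the batching example above could continue is with `jax.vmap`, which vectorizes `apply_matrix` over the leading axis without rewriting it. This is a sketch under that assumption. The names `naive_batch_apply` and `vmap_batch_apply` are illustrative, and the setup is repeated so the block runs on its own.

```python
import jax.numpy as jnp
from jax import jit, random, vmap

key = random.key(1701)
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))


def apply_matrix(x):
    return jnp.dot(mat, x)


def naive_batch_apply(v_batched):
    # Python loop over the batch: simple, but one dispatch per row.
    return jnp.stack([apply_matrix(v) for v in v_batched])


@jit
def vmap_batch_apply(v_batched):
    # vmap maps apply_matrix over axis 0 of its input.
    return vmap(apply_matrix)(v_batched)


out_naive = naive_batch_apply(batched_x)
out_vmap = vmap_batch_apply(batched_x)

print(out_vmap.shape)
```

Both versions return one 150-element output per input row. The `vmap` version keeps the per-example function unchanged and lets JAX handle the batch dimension.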
