# Jax Practical

These notes cover the basics of **Jax**.

The document is not meant to be a tutorial or a complete guide to Jax. It is a quick reference for **Jax** that focuses on the functions most relevant to the author's needs.

For a complete tutorial, please refer to [Jax-read-the-doc](https://jax.readthedocs.io/en/latest/installation.html).

In [2]:
import jax.numpy as jnp
import numpy as np
import jax

## Features of Jax

### Pure Function

1. Definition
    -   All input data is passed in through the function parameters, and all results are returned through the function's return value
    -   A pure function always returns the same result when invoked with the same inputs

2. Jax behavior
    -   JAX traces and compiles the function on the first run, then invokes the cached compilation on later runs
    -   If the type or shape of the input changes, the Python function is traced (and so re-run) again
    -   An *iterator* is not compatible with **jax.jit**, because it carries state that changes between calls

In [4]:
def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

# The side-effects appear during the first run
print("First call: ", jax.jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print("Second call: ", jax.jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print("Third call, different type: ", jax.jit(impure_print_side_effect)(jnp.array([5.])))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]
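
The iterator point above can be sketched in the same way. This is a minimal example, not taken from the original notes: the names `numbers` and `add_next` are made up for illustration. The iterator is read only while JAX traces the function, so the value it yields is frozen into the compiled code.

```python
import jax

# A Python iterator carries hidden state, so a function that reads it is impure.
numbers = iter(range(10))

@jax.jit
def add_next(x):
    # next() runs only while JAX traces the function. The value it returns
    # at that moment (0) is baked into the compiled code as a constant.
    return x + next(numbers)

first = add_next(1.0)   # traces the function and consumes 0 from the iterator
second = add_next(1.0)  # cache hit: the iterator is not touched again

# The iterator advanced only once, during tracing.
remaining = next(numbers)

print("first:", first, "second:", second)
print("next value left in the iterator:", remaining)
```

In plain Python the two calls would return different values. Under `jax.jit` both calls return the same value, and the iterator has advanced only once.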


### Out-of-bounds indexing

1. Definition
    -   An array is indexed at a position outside its bounds

2. Jax Behavior
    -   No error is raised
    -   For retrieval, the index is clamped to the bounds of the array, so an index past the end returns the last value

In [9]:
# NumPy raises an error for an out-of-bounds index
try:
    np.arange(10)[11]
except IndexError:
    print("Index error!")

# Jax clamps the index, so this returns the last value of the array
jnp.arange(10.0).at[11].get()

Index error!


Array(9., dtype=float32)
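
The behavior above applies to retrieval. As a short sketch (variable names are illustrative), the `mode` and `fill_value` arguments of `.at[...].get()` can replace clamping with a fill value, and an out-of-bounds update leaves the array unchanged:

```python
import jax.numpy as jnp

x = jnp.arange(10.0)

# Retrieval: the out-of-bounds index is clamped, so the last element comes back.
clamped = x[11]

# Retrieval with mode="fill": fill_value is returned instead of a clamped element.
filled = x.at[11].get(mode="fill", fill_value=jnp.nan)

# Update: the out-of-bounds write is skipped and the array is returned unchanged.
updated = x.at[11].add(23.0)

print(clamped, filled)
print(updated)
```

Using `fill_value=jnp.nan` makes an out-of-bounds read easier to spot than a silently clamped value.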

### Non-array input

1. Definition
    -   The input is a Python list or tuple rather than an array

2. Jax behavior
    -   Raises a *TypeError*

In [12]:
try:
    # jnp functions do not accept a Python list
    jnp.sum([1, 2, 3])
except TypeError:
    print("TypeError!")

# Convert to an array first
jnp.sum(np.array([1,2,3]))

TypeError!


Array(6, dtype=int32)
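
One way to handle list or tuple inputs is to convert them explicitly before calling a `jnp` function. The helper `sum_sequence` below is hypothetical and only illustrates the pattern with `jnp.asarray`:

```python
import jax.numpy as jnp

def sum_sequence(values):
    """Hypothetical helper: convert a Python sequence to an array, then reduce."""
    return jnp.sum(jnp.asarray(values))

total_list = sum_sequence([1, 2, 3])     # list of ints
total_tuple = sum_sequence((1.5, 2.5))   # tuple of floats

print(total_list, total_tuple)
```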

## JIT (Just-In-Time)

### Jit compilation and Caching

1. Pipeline
    -   Define a function and Jit it
    -   On the first call, the Jax compiler compiles the function and caches the XLA code (this step is expensive)
    -   Subsequent calls reuse the cached code

2. Implementation tip
    -   Avoid calling `jit()` on temporary functions
    -   A function defined inside a loop has a new *hash* in each iteration, so the cached code is never reused

In [41]:
from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! Each call to partial returns
    # a function with a different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this! Each lambda is also
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()

jit called in a loop with partials:
258 ms ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
263 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
2.37 ms ± 33.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
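
The timings above can be cross-checked without `%timeit` by counting how often the Python body is traced. This is a sketch under assumptions: the counter `trace_count` and the two `run_with_*` helpers are hypothetical names, and it relies on a Python side effect running only during tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def loop_body(prev_i):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return prev_i + 1

def run_with_fresh_lambda(n):
    i = jnp.int32(0)
    for _ in range(n):
        # A new lambda object in every iteration, so the cache is missed each time
        i = jax.jit(lambda v: loop_body(v))(i)
    return i

def run_with_reused_jit(n):
    step = jax.jit(loop_body)  # jitted once, outside the loop
    i = jnp.int32(0)
    for _ in range(n):
        i = step(i)
    return i

trace_count = 0
result_lambda = run_with_fresh_lambda(5)
traces_lambda = trace_count

trace_count = 0
result_reused = run_with_reused_jit(5)
traces_reused = trace_count

print("traces with a fresh lambda per iteration:", traces_lambda)
print("traces with one reused jitted function:", traces_reused)
```

Both versions compute the same result. The reused function is traced once, while the lambda version is traced again in each iteration.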


### No Jit

1. Cases where Jit does not work
    -   A condition on a traced value (*if ... else ...*)
    -   A while loop whose test depends on a traced value (*while ...*)

2. Explanation
    -   A **jit** function is compiled from the type and shape of its inputs, not their values, so Python control flow cannot branch on a value

3. Solution
    -   Avoid conditions on the value
    -   Use `jax.lax.cond()`

In [37]:
# Counter example
def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

try:
    jax.jit(g)(10, 20)  # Raises an error
except jax.errors.TracerBoolConversionError:
    print("TracerBoolConversionError!")

# Proper way
@jax.jit
def loop_body(prev_i):
  # traced and compiled only once
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 10)

TracerBoolConversionError!


Array(20, dtype=int32, weak_type=True)
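
The `jax.lax.cond()` solution listed above is not shown in the cell, so here is a minimal sketch. It also shows two related options that these notes do not cover: `jax.lax.while_loop` and the `static_argnums` argument of `jax.jit`. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp

@jax.jit
def negate_or_square(x):
    # lax.cond traces both branches and picks one at run time,
    # so the predicate may depend on a traced value.
    return jax.lax.cond(x < 0, lambda v: -v, lambda v: v * v, x)

@jax.jit
def g_while_loop(x, n):
    # The carry is the loop counter. cond_fun and body_fun are traced once.
    i = jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, jnp.int32(0))
    return x + i

def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

# Marking n as static makes it a concrete Python int during tracing,
# so the Python while loop runs at trace time. Each new n recompiles.
g_static = jax.jit(g, static_argnums=1)

print(negate_or_square(-3.0), negate_or_square(2.0))
print(g_while_loop(10, 20))
print(g_static(10, 20))
```

`static_argnums` is simplest when the argument takes few distinct values, since every new value triggers a new compilation. The `lax` primitives keep a single compiled version for all values.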

## Debugging

1. Rule of thumb
    -   Use `jax.debug.print()` for traced (runtime) array values
    -   Use `print()` for static values such as dtypes and array shapes

In [43]:
@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)

# runtime value
@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)

print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
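
The second half of the rule of thumb, `print()` for static values, can be sketched as follows. The function `normalize` is a made-up example: shapes and dtypes are known at trace time, while the sum is only known at run time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Static information: available while tracing, printed once per compilation
    print("trace time ->", x.shape, x.dtype)
    total = jnp.sum(x)
    # Runtime value: only available when the compiled code executes
    jax.debug.print("run time -> total = {total}", total=total)
    return x / total

out = normalize(jnp.array([1.0, 2.0, 5.0]))
print(out)
```

Calling `normalize` again with an array of the same shape and dtype skips the `print()` line but still runs `jax.debug.print()`.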
