## JAX Basics

### JAX API Layering

Like many libraries, JAX has a layered ("onion") API structure:

- `jax.numpy` <--> `jax.lax` <--> XLA
- The `lax` API is stricter but more powerful than the NumPy-style API
- `lax` is a Python wrapper around XLA

In [12]:
import jax.numpy as jnp
import numpy as np
import jax

from jax import grad, jit, pmap, vmap
from jax import lax

In [13]:
seed = 0
key = jax.random.PRNGKey(seed)

#### lax API is stricter

In [4]:
print(jnp.add(1,1.0))
# Implicitly promotes mixed types

2.0


In [6]:
print(lax.add(1,1.0))
# Needs explicit conversion

TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).
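
One way to satisfy `lax.add` is to bring both operands to the same dtype before the call. A minimal sketch, assuming `float32` is the result type you want (the variable names are illustrative only):

```python
import jax.numpy as jnp
from jax import lax

# Assumption: float32 is the desired common dtype.
a = lax.convert_element_type(jnp.asarray(1), jnp.float32)  # int -> float32
b = jnp.asarray(1.0, dtype=jnp.float32)

# Both operands now share a dtype, so lax.add accepts them.
explicit_sum = lax.add(a, b)
print(explicit_sum, explicit_sum.dtype)
```
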

#### lax is more powerful but less user-friendly

In [10]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)

# NumPy API
result1 = jnp.convolve(x, y)

# lax API
result2 = lax.conv_general_dilated(
    x.reshape(1,1,3).astype(float), # Explicit Promotion
    y.reshape(1,1,10),
    window_strides=(1,),
    padding=[(len(y)-1, len(y)-1)]    # Equivalent to mode="full" in NumPy
)

print(result1)
print(result2[0][0])

assert(np.allclose(result1,result2[0][0], atol=1e-6))

[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]
[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]
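
A caveat on the comparison above: `lax.conv_general_dilated` computes a cross-correlation (it does not flip the kernel), so the two results agree here only because `y` is all ones. The sketch below uses a deliberately asymmetric kernel to show the difference. The helper `lax_conv_full` and the sample values are illustrative, not part of the original notebook.

```python
import jax.numpy as jnp
from jax import lax

signal = jnp.array([1.0, 2.0, 3.0])
kernel = jnp.array([0.0, 1.0, 0.5])  # deliberately asymmetric

def lax_conv_full(signal, kernel):
    # Hypothetical helper: lhs layout is (batch, channel, width),
    # rhs layout is (out_channel, in_channel, width).
    return lax.conv_general_dilated(
        signal.reshape(1, 1, -1),
        kernel.reshape(1, 1, -1),
        window_strides=(1,),
        padding=[(len(kernel) - 1, len(kernel) - 1)],
    )[0, 0]

reference = jnp.convolve(signal, kernel)       # true convolution, mode="full"
unflipped = lax_conv_full(signal, kernel)      # cross-correlation
flipped = lax_conv_full(signal, kernel[::-1])  # kernel flipped by hand

print(reference)
print(unflipped)
print(flipped)
```

Only the flipped variant should match `jnp.convolve` for a kernel that is not symmetric.
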


### JIT-Compiled Functions Are Faster

In [15]:
def norm(x):
    x = x - x.mean(0)
    return x/x.std(0)

norm_compiled = jit(norm)

x = jax.random.normal(key, (10000, 100), dtype=jnp.float32)

assert np.allclose(norm(x), norm_compiled(x), atol=1e-6)

In [16]:
print("Normal Function")
%timeit norm(x).block_until_ready()

print("JIT Compiled Function")
%timeit norm_compiled(x).block_until_ready()

Normal Function
1.41 ms ± 25.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
JIT Compiled Function
730 μs ± 9.81 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### 1) Example of Failure: Array shapes must be static 

In [18]:
def get_negatives(x):
    return x[x<0]

x = jax.random.normal(key, (10,), dtype=jnp.float32)
print(get_negatives(x))

[-0.43359444 -0.07861735 -0.97208923 -0.49529874 -0.9501635 ]


In [19]:
print(jit(get_negatives)(x))

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
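
One possible workaround, assuming a result with the same shape as the input is acceptable: keep every element and mask out the unwanted ones with the three-argument `jnp.where`. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_out_non_negatives(x):
    # Output shape equals input shape, so it is known at trace time.
    return jnp.where(x < 0, x, 0.0)

@jax.jit
def sum_of_negatives(x):
    # Reductions over the masked array also have a static shape.
    return jnp.where(x < 0, x, 0.0).sum()

sample = jnp.array([-1.0, 2.0, -3.0, 4.0])
masked = zero_out_non_negatives(sample)
total = sum_of_negatives(sample)
print(masked, total)
```

This does not return a shorter array as `x[x<0]` does. It trades that for a shape that JIT can determine without looking at the values.
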

### So how does JIT work in the background?

In [22]:
@jit
def f(x,y):
    print("Running f():")
    print(f"x = {x}")
    print(f"y = {y}")
    result = jnp.dot(x+1, y+1)
    print(f"Result: x.y = {result}")
    return result

x = np.random.randn(3,4)
y = np.random.randn(4)
print(f(x,y))

Running f():
x = JitTracer<float32[3,4]>
y = JitTracer<float32[4]>
Result: x.y = JitTracer<float32[3]>
[6.2779365 2.786862  2.5150263]


In [23]:
x2 = np.random.randn(3,4)
y2 = np.random.randn(4)
print("Second Time (But actually the Third Call)")
print(f(x2, y2))

Second Time (But actually the Third Call)
[4.434293  3.7576451 1.2664586]


- Side effects (such as `print` statements) are not compiled; they run only while the function is being traced
- The first time you call the jitted function, JAX runs a "trace" in the background
- During tracing, the actual values of x and y are not passed in
- Instead, JAX creates abstract tracer values: placeholders that carry only a shape and a data type
- This is what was printed: `JitTracer<>`
- Tracing lets JIT record the shape and dtype of the inputs, how they are transformed, and the shape and dtype of the output
- On the second call (same shapes and dtypes), JAX reuses the cached compiled function, so the side effects do not run again
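
If you do want to see concrete values from inside a jitted function, `jax.debug.print` is one option: it is staged into the compiled computation instead of running only at trace time. A minimal sketch, with a hypothetical function name:

```python
import jax
import jax.numpy as jnp

@jax.jit
def dot_with_runtime_print(x, y):
    result = jnp.dot(x + 1, y + 1)
    # Unlike Python's print, this runs on every call and shows real values.
    jax.debug.print("result = {r}", r=result)
    return result

a = jnp.ones((3, 4))
b = jnp.ones(4)
out = dot_with_runtime_print(a, b)
out = dot_with_runtime_print(a, b)  # prints again, even though the cache is hit
```
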

In [24]:
x3 = np.random.randn(4,5)
y3 = np.random.randn(5)
print(f(x3, y3))

Running f():
x = JitTracer<float32[4,5]>
y = JitTracer<float32[5]>
Result: x.y = JitTracer<float32[4]>
[4.5076585 4.263817  5.343482  5.9992876]


- This time JIT traced and compiled the function again
- This happened because the shapes of the input arrays changed
- The compiled function is cached per input shape and dtype, so a new shape triggers a retrace
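
To make the retracing visible without reading printed tracers, you can record each trace in a Python list. This sketch is deliberately impure (it appends to a global list) and only meant for inspection. `trace_log` and `add_one` are hypothetical names.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side record; appended to only while tracing

@jax.jit
def add_one(x):
    trace_log.append((x.shape, str(x.dtype)))
    return x + 1

add_one(jnp.ones(3))                   # first call: traces
add_one(jnp.zeros(3))                  # same shape and dtype: cache hit
add_one(jnp.ones(4))                   # new shape: retraces
add_one(jnp.ones(3, dtype=jnp.int32))  # new dtype: retraces

print(len(trace_log), trace_log)
```

Four calls should leave three entries in `trace_log`: one per distinct shape/dtype combination.
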

#### Same function but no side effects

In [27]:
def f(x, y):
    return jnp.dot(x+1, y+1)

print(jax.make_jaxpr(f)(x, y))
# make_jaxpr: traces f and returns its jaxpr (JAX expression)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0:f32[]
    d:f32[4] = add b 1.0:f32[]
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }


#### This jaxpr (JAX expression) is what JIT builds when it traces the function in the background
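
A small sketch of the same idea with a side effect left in: the Python `print` runs while the jaxpr is being built, but leaves no trace in the jaxpr itself. `scale_and_shift` is a hypothetical function.

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x):
    print("tracing scale_and_shift")  # Python side effect, runs during tracing
    return x * 2.0 + 1.0

jaxpr = jax.make_jaxpr(scale_and_shift)(jnp.ones(3))
jaxpr_text = str(jaxpr)
print(jaxpr_text)
```

The printed jaxpr should contain only the array operations (a `mul` and an `add`).
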

#### 2) Another Example of JIT Failure: Control Flow That Depends on Concrete Values

In [29]:
@jit
def f(x, neg):
    return -x if neg else x

f(1, True)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at C:\Users\gr8my\AppData\Local\Temp\ipykernel_34648\2013549118.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

#### Workaround: Static Arguments

In [30]:
from functools import partial

In [31]:
@partial(jit, static_argnums=(1,))
def f(x, neg):
    print(x)
    return -x if neg else x

print(f(1, True))

JitTracer<~int32[]>
-1


In [33]:
print(f(2, True))

-2


In [34]:
print(f(2, False))

JitTracer<~int32[]>
2


In [35]:
print(f(3, False))

3


- When an argument is marked static, JIT does not replace it with an abstract tracer; it uses the actual value
- So we are lowering the level of abstraction while tracing
- Every time the value of that static argument changes, the function is retraced and recompiled (which is why the tracer is printed again in the `f(2, False)` call)
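
If recompiling for every new value of the flag is undesirable, an alternative is to keep the flag traced and express the branch with `lax.cond` or `jnp.where`. A minimal sketch, with hypothetical function names:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def negate_if_cond(x, neg):
    # Both branches are traced once; the choice is made at run time.
    return lax.cond(neg, lambda v: -v, lambda v: v, x)

@jax.jit
def negate_if_where(x, neg):
    # Computes both candidates and selects one elementwise.
    return jnp.where(neg, -x, x)

print(negate_if_cond(1, True), negate_if_cond(2, False))
print(negate_if_where(1, True), negate_if_where(2, False))
```

Here changing `neg` between `True` and `False` does not trigger a retrace, because only its shape and dtype are part of the cache key.
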

#### 3) Another Failure: A Traced Object Is Passed to a Function That Expects a Concrete Value

In [37]:
@jit
def f(x):
    print(x)
    print(x.shape)
    print(jnp.array(x.shape).prod())
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2,3))
f(x)

JitTracer<float32[2,3]>
(2, 3)
JitTracer<int32[]>


TypeError: Shapes must be 1D sequences of concrete values of integer type, got [JitTracer<int32[]>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at C:\Users\gr8my\AppData\Local\Temp\ipykernel_34648\3784783456.py:1 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = reduce_prod[axes=(0,)] b
    from line C:\Users\gr8my\AppData\Local\Temp\ipykernel_34648\3784783456.py:6:21 (f)

#### Workaround: Use numpy instead of jax.numpy

In [38]:
@jit
def f(x):
    return x.reshape((np.prod(x.shape),))

f(x)

Array([1., 1., 1., 1., 1., 1.], dtype=float32)
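
The workaround succeeds because `x.shape` is a plain tuple of Python integers even during tracing. Only the `jnp` call turned it into a tracer. Two other ways to keep the shape arithmetic in Python, sketched with hypothetical function names:

```python
import math
import jax
import jax.numpy as jnp

@jax.jit
def flatten_with_math(x):
    # math.prod works on the Python ints in x.shape and returns a Python int.
    return x.reshape((math.prod(x.shape),))

@jax.jit
def flatten_with_minus_one(x):
    # -1 lets reshape infer the size from the static shape.
    return x.reshape(-1)

ones = jnp.ones((2, 3))
flat_a = flatten_with_math(ones)
flat_b = flatten_with_minus_one(ones)
print(flat_a.shape, flat_b.shape)
```
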

## Pure Functions
- JAX is designed to work with pure functions


#### What are pure functions? (Informal Definition):
1. All input data is passed in through the function parameters, and all results are returned through the function's return values
2. A pure function always returns the same result when invoked with the same inputs

#### Example 1

In [41]:
def impure_print_side_effect(x):
    # Violating #1
    print("Executing Function")

    return x

In [42]:
# Side Effects will appear during the first run
print("First Call: ", jit(impure_print_side_effect)(4.))

Executing Function
First Call:  4.0


In [43]:
# Subsequent runs with parameters of same type and shape will not show these side effects
# This is because JAX now invokes a cached, compiled version of the function
print("Second Call: ", jit(impure_print_side_effect)(6.))

Second Call:  6.0


In [44]:
# JAX will rerun the trace when the type or shape changes
print("Third Call: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing Function
Third Call:  [5.]


#### Example 2

In [45]:
g = 0
def impure_uses_global(x):
    # Violates #1 and #2
    return x+g

In [46]:
# JAX will capture the value of the global variable during first run
print("First Call: ", jit(impure_uses_global)(4.))

First Call:  4.0


In [47]:
# Subsequent calls will silently use the cached value of global var
g = 10
print("Second Call: ", jit(impure_uses_global)(4.))

Second Call:  4.0


In [49]:
# The value is updated only when JAX has to rerun the trace
print("Third Call: ", jit(impure_uses_global)(jnp.array([5.])))

Third Call:  [15.]
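
One way to avoid the stale global is to pass the value in as a regular argument, so it becomes part of the traced inputs. A minimal sketch, where `pure_add_offset` is a hypothetical name:

```python
import jax

@jax.jit
def pure_add_offset(x, offset):
    # offset is an explicit input, so every call sees its current value.
    return x + offset

first = pure_add_offset(4.0, 0.0)
second = pure_add_offset(4.0, 10.0)
print(first, second)
```
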


#### Example 3: Haiku/Flax are built upon this idea

In [None]:
def pure_uses_internal_state(x):
    state = dict(even=0, odd=0)
    for i in range(10):
        state["even" if i%2==0 else "odd"] += x
    return state['even']+state['odd']

# Although this function mutates state inside a loop, the state is local,
# so neither condition is violated

In [51]:
print(jit(pure_uses_internal_state)(5.))

50.0


In [53]:
print(jit(pure_uses_internal_state)(2.))

20.0
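
When state has to survive across calls, a common pure pattern is to pass it in and return the updated version. The sketch below illustrates that pattern only. It is not Haiku's or Flax's actual API, and `accumulate` and the state layout are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def accumulate(state, x):
    # Build and return a new state instead of mutating anything in place.
    new_state = {"total": state["total"] + x, "count": state["count"] + 1}
    mean = new_state["total"] / new_state["count"]
    return new_state, mean

state = {"total": jnp.float32(0.0), "count": jnp.int32(0)}
for value in [2.0, 4.0, 6.0]:
    state, mean = accumulate(state, value)

print(state, mean)
```

The caller owns the state and threads it through each call, so the jitted function stays pure.
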


#### Example 4: No Iterators! Iterators are stateful

In [55]:
# Using lax.fori_loop
# The same applies to lax.scan, lax.cond, etc.

# fori_loop(lower, upper, body_func, init_val)

# Correct usage
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0))
# Expected value 45

# Wrong usage
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0))
# Unexpected value 0: next(iterator) is evaluated while tracing, not once per iteration

45
0
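
The same sum can also be written with `lax.scan`, which passes the array elements to the loop body explicitly, so no iterator is needed. A minimal sketch, where `add_step` is a hypothetical name:

```python
import jax.numpy as jnp
from jax import lax

values = jnp.arange(10)

def add_step(carry, value):
    # carry is the running sum; value is the current element of `values`.
    return carry + value, None  # no per-step output is collected

init = jnp.zeros((), dtype=values.dtype)  # match the dtype of the elements
total, _ = lax.scan(add_step, init, values)
print(total)
```
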
