## **JAX**

JAX is a library for array-oriented numerical computation, with automatic differentiation and JIT compilation to enable high-performance machine learning research.

1. JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
2. JAX features built-in Just-in-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
3. JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
4. JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

```python
import jax.numpy as jnp
```

With this import, we can immediately start using JAX much like NumPy:

```python
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))
```

```
[0.        1.05      2.1       3.1499999 4.2      ]
```
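The list at the top of this section mentions JIT compilation, automatic differentiation and vectorization. The following is a minimal sketch of how `jax.jit`, `jax.grad` and `jax.vmap` compose with the `selu` function above. The function is redefined so the block is self-contained. The names `selu_jit` and `dselu` are illustrative only.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)

# JIT-compile with XLA; the results match the un-jitted version
selu_jit = jax.jit(selu)
print(selu_jit(x))

# grad requires a scalar-valued function, so differentiate selu at a single point
dselu = jax.grad(selu)
print(dselu(2.0))  # for x > 0 the derivative equals lmbda (1.05)

# vmap maps the scalar derivative over a batch of inputs
print(jax.vmap(dselu)(x))  # at x = 0 the other branch applies: alpha * lmbda
```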


JAX works great for many numerical and scientific programs, but only if they are written with certain constraints in mind, as explained below.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, lax, random
```

### Pure Functions

JAX transformations and compilation are designed to work only on Python functions that are **functionally pure**, i.e., all input data is passed through the function parameters and all results are output through the function results. A pure function always returns the same result when invoked with the same inputs.

Let's look at an impure function:

```python
def impure_print_side_effect(x):
    print("This is a side effect")
    return x + 1
```

```python
print(f"First call: {jit(impure_print_side_effect)(4.)}")
```

```
This is a side effect
First call: 5.0
```

```python
print(f"Second call: {jit(impure_print_side_effect)(5.)}")
```

```
Second call: 6.0
```

```python
print("Third call, different type:", jit(impure_print_side_effect)(jnp.array([6.])))
```

```
This is a side effect
Third call, different type: [7.]
```

Notice that the side effect appears on the first and third calls but not on the second.


A pure function is one that:
- Only depends on its inputs
- Returns the same output every time for the same inputs
- Does not produce side effects (like printing, writing files, modifying global state)


Example of a pure function:

```python
def add(x,y):
    return x+y
```

This is pure: it returns the same result every time for the same inputs, and it neither prints nor writes files.

An impure function is one that:
- Depends on external/global state or randomness
- Produces side effects (like printing, logging, writing to disk)
- Might return different outputs for the same inputs


The `impure_print_side_effect` function above is an example, because printing is a side effect.
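Randomness is another common source of impurity. The following is a minimal sketch that assumes NumPy's global random generator as the impure source. Under `jit`, `np.random.normal()` runs only while tracing, so its sample is frozen into the compiled code. JAX's `jax.random` API keeps functions pure by taking an explicit PRNG key as an argument. The names `impure_noise`, `pure_noise`, `noisy` and `pure` are illustrative only.

```python
import numpy as np
import jax

def impure_noise(x):
    # Reads NumPy's hidden global RNG state: impure
    return x + np.random.normal()

def pure_noise(x, key):
    # All randomness comes from the explicit key argument: pure
    return x + jax.random.normal(key)

noisy = jax.jit(impure_noise)
# np.random.normal() ran once during tracing, so its sample is now a constant
print(noisy(1.0), noisy(1.0))  # the same value twice

key = jax.random.PRNGKey(0)
pure = jax.jit(pure_noise)
print(pure(1.0, key) == pure(1.0, key))  # same key, same result

# For fresh randomness, split the key explicitly instead of mutating hidden state
key, subkey = jax.random.split(key)
print(pure(1.0, subkey))
```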

#### **Why Does it Happen?**




1. JAX uses tracing to compile functions. When we apply `jit`, as in:
```python
@jit
def f(x):
    return x+1
```

JAX doesn't just execute `f`. Instead, it:
- traces the function with placeholder values (called tracers)
- builds a computation graph (a jaxpr, which JAX then lowers to XLA's intermediate representation)
- compiles it into efficient machine code for the target device

During this tracing phase, JAX runs your Python function once with abstract inputs to "record" its operations. These inputs are tracers that carry a shape and dtype but no concrete values.
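To see what tracing records, here is a minimal sketch using `jax.make_jaxpr`. It performs only the tracing step and returns the captured program (a jaxpr) without compiling it. The `print` inside `f` shows that the function receives a tracer, not the number `3.0`.

```python
import jax

def f(x):
    # During tracing, x is an abstract tracer, not the concrete value passed in
    print("traced with:", x)
    return x + 1

# make_jaxpr traces f and returns the recorded operations
jaxpr = jax.make_jaxpr(f)(3.0)
print(jaxpr)
```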


2. JAX caches compiled functions.

Once JAX compiles a function for a specific input shape and dtype, it reuses the compiled version on later calls with matching shapes and dtypes. This is where functional purity matters:
 - If your function has side effects, they happen only during tracing.
 - Later calls that use the cached compiled code skip them, breaking any logic that depends on them.

```python
def impure_fn(x):
    print("This prints only during tracing!")
    return x + 1

# Apply JAX's jit (just-in-time compiler)
fast_fn = jit(impure_fn)

# First call: triggers tracing and compilation
print("First call:", fast_fn(4.0))

# Second call: same shape and dtype, so the cached compiled version is reused (no tracing)
print("Second call:", fast_fn(5.0))

# Third call: a different type/shape triggers retracing
print("Third call (with array):", fast_fn(jnp.array([6.0])))
```

```
This prints only during tracing!
First call: 5.0
Second call: 6.0
This prints only during tracing!
Third call (with array): [7.]
```


It's best to keep `print` calls and other side effects outside `jit`-compiled functions.
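To inspect values inside a jitted function, JAX provides `jax.debug.print`. It is staged into the compiled program, so it runs on every call. The following minimal sketch contrasts it with Python's built-in `print`.

```python
import jax

@jax.jit
def f(x):
    print("Python print: runs only while tracing")
    # jax.debug.print becomes part of the compiled computation, so it runs on every call
    jax.debug.print("jax.debug.print: x = {x}", x=x)
    return x + 1

f(4.0)  # both messages appear (tracing, then execution)
f(5.0)  # only the jax.debug.print message appears (cached compiled code)
```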

The same applies to global values:

```python
g = 0

def impure_fn_with_global(x):
    print("Tracing impure_fn_with_global")
    return x + g

print("First call with global:", jit(impure_fn_with_global)(4.0))
g = 1  # Change the global variable
# Naively we'd expect 6 here, but the cached trace baked in g = 0, so we get 5
print("Second call with changed global:", jit(impure_fn_with_global)(5.0))

# Changing the input type/shape forces a retrace, which now reads g = 1
print("Third call with array input:", jit(impure_fn_with_global)(jnp.array([6.0])))
```

```
Tracing impure_fn_with_global
First call with global: 4.0
Second call with changed global: 5.0
Tracing impure_fn_with_global
Third call with array input: [7.]
```
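One way to fix this is to make the dependency explicit. In the following minimal sketch, `g` is passed as an argument, so it is traced like `x` instead of being captured as a compile-time constant. The name `fn_with_explicit_g` is illustrative only.

```python
import jax

def fn_with_explicit_g(x, g):
    # g is an input now, so each call's value is used rather than a baked-in constant
    return x + g

f = jax.jit(fn_with_explicit_g)
print(f(4.0, 0.0))  # 4.0
print(f(5.0, 1.0))  # 6.0: the new g is honored without retracing
```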


A Python function can be functionally pure even if it uses stateful objects internally, as long as it does not read or write external state.

```python
def pure_uses_internal_state(x):
    # Pure: the dict is created and mutated locally; no external state is read or written
    state = dict(even=0, odd=0)
    for i in range(10):
        state['even' if i % 2 == 0 else 'odd'] += x
    return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(4.0))  # works fine: no side effects
```

```
40.0
```


Avoid using iterators in any JAX function you want to `jit` or in any control-flow primitive. An iterator is a Python object that carries internal state to produce its next element, so it is incompatible with JAX's functional programming model.


Here are some examples of incorrect attempts to use iterators with JAX. Most raise errors, and some give unexpected results.

```python
import jax.numpy as jnp
from jax import lax, make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, x: x + array[i], 0))  # expected result 45
iterator = iter(range(10))
# unexpected result 0: next(iterator) runs at trace time, not once per iteration
print(lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0))

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))

make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.)  # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x + 1, lambda x: x - 1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x) + 1, lambda x: next(x) - 1, iter_operand)  # throws error
```

```
45
0
```
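The following is a minimal sketch of one fix. It materializes the iterator into an array outside any traced code and lets `lax.scan` consume one element per step. The array takes the place of what `next(iterator)` was meant to do. The names `xs` and `body` are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

iterator = iter(range(10))
# Materialize the iterator into an immutable array before any tracing happens
xs = jnp.array(list(iterator))

def body(total, x):
    # scan feeds one element of xs per step, replacing next(iterator)
    return total + x, None

# Start the carry with the same dtype as xs so the carry type stays consistent
total, _ = lax.scan(body, jnp.array(0, dtype=xs.dtype), xs)
print(total)  # 45
```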
