<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="300" height="300" align="center"/><br>

Welcome to another JAX tutorial. I hope you all have been enjoying the JAX Tutorials so far. We have already completed three tutorials on JAX, each of which introduced an important concept.

In the first tutorial, we discussed **DeviceArray**, the core data structure in JAX. In the second tutorial, we looked into **Pure Functions** and their pros and cons. In the third tutorial, we looked into **Pseudo-Random Number Generation** in JAX, and how it differs from NumPy's PRNG. If you haven't gone through the previous tutorials, I highly suggest going through them. Here are the links:

1. [TF_JAX_Tutorials - Part 1](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1)
2. [TF_JAX_Tutorials - Part 2](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part2)
3. [TF_JAX_Tutorials - Part 3](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3)
4. [TF_JAX_Tutorials - Part 4 (JAX and DeviceArray)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray)
5. [TF_JAX_Tutorials - Part 5 (Pure Functions in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-5-pure-functions-in-jax/)
6. [TF_JAX_Tutorials - Part 6 (PRNG in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-6-prng-in-jax/)


Today we will look into another important concept: **Just In Time Compilation (JIT)** in JAX.

# What is Just In Time (JIT) Compilation?

If we go by the [definition](https://en.wikipedia.org/wiki/Just-in-time_compilation) of JIT, then JIT is a way of compiling your code during execution rather than ahead of time. A system implementing a JIT compiler typically analyses the code being executed continuously and identifies parts of the code where the speedup gained from compilation or recompilation would outweigh the overhead of compiling that code.


# JIT in JAX

As we discussed in the first chapter, JAX uses XLA for compilation. The `jax.jit(...)` transform does the just-in-time compilation: it **transforms** your normal JAX Python functions so that they can be executed **more efficiently** in XLA.
Let's see a few examples of it before discussing the details.

In [None]:
import time

import jax
import jax.numpy as jnp
from jax import jit, random

%config IPCompleter.use_jedi = False

In [4]:
def apply_activation(x):
    return jnp.maximum(0.0, x)

def get_dot_product(W, X):
    return jnp.dot(W, X)

In [5]:
# Always use a seed
key = random.PRNGKey(1234)

# Never reuse a key: split it into fresh subkeys instead
key, w_key, x_key = random.split(key, 3)
W = random.normal(key=w_key, shape=[1000, 10000], dtype=jnp.float32)
X = random.normal(key=x_key, shape=[10000, 20000], dtype=jnp.float32)

# JIT the functions we have
dot_product_jit  = jit(get_dot_product)
activation_jit = jit(apply_activation)

for i in range(3):
    start = time.time()
    # Don't forget to use `block_until_ready(..)`
    # else you will be recording dispatch time only
    Z = dot_product_jit(W, X).block_until_ready()
    end = time.time()
    print(f"Iteration: {i+1}")
    print(f"Time taken to execute dot product: {end - start:.2f} seconds", end="")
    
    start = time.time()
    A = activation_jit(Z).block_until_ready()
    print(f", activation function: {time.time()-start:.2f} seconds")

Iteration: 1
Time taken to execute dot product: 3.55 seconds, activation function: 0.06 seconds
Iteration: 2
Time taken to execute dot product: 3.57 seconds, activation function: 0.03 seconds
Iteration: 3
Time taken to execute dot product: 3.87 seconds, activation function: 0.03 seconds


Let's break down the above example into steps to know in detail what happened under the hood.

1. We defined two functions namely, `get_dot_product(...)` that does a dot product of weights and the inputs, and `apply_activation(...)` that applies `relu` on the previous result.
2. We then applied the `jit(function_name)` transformation to each of them and got **compiled** versions of our functions.
3. When you call a compiled function for the first time with arguments of a given shape and dtype, the call takes longer. Why? Because the first call serves as the `warmup` phase: JAX **traces** the function, converting it into an intermediate language, **`jaxprs`** (we will talk about this in a bit), which is then compiled by XLA. In the run above, this is most visible for the activation function (0.06 seconds on the first iteration versus 0.03 seconds afterwards); the cost of the dot product itself dominates its timing, so its compilation overhead is hard to spot.
4. Subsequent calls with inputs of the same shape and dtype reuse the cached compiled code.

**Note:** If you are benchmarking the `jit` version of your function against something else, do a warmup call first for a fair comparison; otherwise the compilation time will be included in your benchmark.
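A minimal sketch of a benchmark helper that follows this advice, assuming you only care about wall-clock time. The helper name `benchmark` and its `warmup`/`runs` parameters are our own, not a JAX API. It calls the function once to absorb tracing and compilation, then averages several timed calls, each synchronized with `block_until_ready()`.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, warmup=1, runs=5):
    """Hypothetical helper: average wall-clock seconds per call after warmup."""
    # Warmup calls trigger tracing and XLA compilation; their time is discarded
    for _ in range(warmup):
        fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(runs):
        # Wait for the result so we time execution, not just dispatch
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / runs


def apply_activation(x):
    return jnp.maximum(0.0, x)


activation_jit = jax.jit(apply_activation)
x = jnp.ones((1000, 1000), dtype=jnp.float32)
avg_seconds = benchmark(activation_jit, x)
print(f"Average time per call after warmup: {avg_seconds:.6f} seconds")
```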

Before continuing further with JIT transformations, we will take a break here and try to understand the concept of **`jaxprs`** first.

# Jaxprs

A jaxpr is an intermediate language for representing normal Python functions. When you transform a function, it is first converted into simple, statically typed intermediate expressions in the jaxpr language, and the transformations are then applied directly to these jaxprs.

1. A jaxpr instance represents a function with one or more typed parameters (input variables) and one or more typed results.
2. The inputs and outputs have `types` and are represented as abstract values.
3. Not all Python programs can be represented by jaxprs, but many scientific computations and machine learning programs can.
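To make these points concrete, here is a small sketch using `jax.make_jaxpr` on a plain (un-jitted) function; `scale_and_shift` is just an illustrative example. The returned object exposes the typed input and output variables and their abstract values (shape and dtype) rather than concrete data. The exact printed text of a jaxpr differs between JAX versions.

```python
import jax
import jax.numpy as jnp


def scale_and_shift(x):
    # Two simple elementwise operations
    return x * 2.0 + 1.0


closed_jaxpr = jax.make_jaxpr(scale_and_shift)(jnp.ones(3, dtype=jnp.float32))
print(closed_jaxpr)

# One typed input variable and one typed result
print(closed_jaxpr.jaxpr.invars, closed_jaxpr.jaxpr.outvars)
# Abstract values: shape and dtype only, no concrete numbers
print(closed_jaxpr.in_avals, closed_jaxpr.out_avals)
```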


## Should you learn about Jaxprs?

Every transformation in JAX materializes to some form of `jaxpr`. If you want to understand how JAX works internally, or if you want to understand the result of JAX tracing, then yes, it is useful to understand jaxprs.


Let's take a few examples of how jaxpr works. We will first see how the functions we defined above are expressed as jaxprs.

In [6]:
# Make jaxpr for the activation function
print(jax.make_jaxpr(activation_jit)(Z))

{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = max 0.0 a
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    inline=False
                    name=apply_activation ] a
  in (b,) }


How to interpret this jaxpr?

1. The first line tells you that the function receives one argument, `a`.
2. The `xla_call` on the second line wraps our jitted function; its `call_jaxpr` holds the body that XLA will execute, here `max 0.0 a`, the elementwise maximum of 0 and `a`.
3. The last line tells you the output being returned, `b`.

Let's look at the jaxpr of our function that computes the dot product.

In [7]:
# Make jaxpr for the dot product function
print(jax.make_jaxpr(dot_product_jit)(W, X))

{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                                                      precision=None
                                                      preferred_element_type=None ] a b
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    inline=False
                    name=get_dot_product ] a b
  in (c,) }


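Similar to the above:
1. The first line tells you that the function receives two input variables, `a` and `b`, corresponding to our `W` and `X`.
2. The second line is an XLA call where we perform the dot operation. Check the `dimension_numbers` used by `dot_general`: `(((1,), (0,)), ((), ()))` means that axis 1 of `a` is contracted with axis 0 of `b` and that there are no batch dimensions, i.e., an ordinary matrix multiplication.
3. The last line is the result to be returned, denoted by `c`.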
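A small sketch to check this reading of `dimension_numbers`: calling `jax.lax.dot_general` directly with the same dimension numbers on two small random matrices (shapes chosen arbitrarily) should give the same result as `jnp.dot`.

```python
import jax.numpy as jnp
from jax import lax, random

key = random.PRNGKey(0)
a_key, b_key = random.split(key)
a = random.normal(a_key, (4, 6))
b = random.normal(b_key, (6, 5))

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dimension_numbers = (((1,), (0,)), ((), ()))
c = lax.dot_general(a, b, dimension_numbers=dimension_numbers)

print(c.shape)  # (4, 5)
print(jnp.allclose(c, jnp.dot(a, b), atol=1e-5))
```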


Let's take another interesting example.

In [8]:
# We know that `print` introduces impurity, but it is
# also very useful for printing values while debugging.
# How do jaxprs interpret that?

def number_squared(num):
    print("Received: ", num)
    return num ** 2

# Compiled version
number_squared_jit = jit(number_squared)

# Make jaxprs
print(jax.make_jaxpr(number_squared_jit)(2))

Received:  Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)>
{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = integer_pow[ y=2 ] a
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    inline=False
                    name=number_squared ] a
  in (b,) }


Notice how `num` inside the print statement shows up as a tracer rather than the value `2`: while JAX traces the function, `num` is an abstract value. Nothing stops you from running an impure function, but you should be ready to encounter such side effects. The print statement runs while the Python function is being traced, which happens at least once, but not when the cached compiled code is reused on subsequent calls. Let's see that in action as well.

In [9]:
# Subsequent calls to the jitted function
for i, num in enumerate([2, 4, 8]):
    print("Iteration: ", i+1)
    print("Result: ", number_squared_jit(num))
    print("="*50)

Iteration:  1
Received:  Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Result:  4
Iteration:  2
Result:  16
Iteration:  3
Result:  64
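If you want to see the runtime values on every call, a common approach in recent JAX versions is `jax.debug.print`, which is staged into the compiled computation instead of running only at trace time. A minimal sketch, assuming a JAX version that provides `jax.debug.print`:

```python
import jax


@jax.jit
def number_squared(num):
    # Plain Python print: runs only while JAX traces the function
    print("Traced with:", num)
    # Staged into the compiled computation: prints the runtime value on every call
    jax.debug.print("Received: {}", num)
    return num ** 2


results = [number_squared(num) for num in [2, 4, 8]]
print([int(r) for r in results])  # [4, 16, 64]
```

The plain `print` line should appear once, while the `jax.debug.print` line should appear for each of the three calls.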


We will take one more example to appreciate the beauty of `jaxprs` before moving on to JIT again.

In [10]:
squared_numbers = []

# An impure function (using a global state)
def number_squared(num):
    global squared_numbers
    squared = num ** 2
    squared_numbers.append(squared)
    return squared

# Compiled version
number_squared_jit = jit(number_squared)

# Make jaxpr
print(jax.make_jaxpr(number_squared_jit)(2))

{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = integer_pow[ y=2 ] a
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    inline=False
                    name=number_squared ] a
  in (b,) }


A few things to notice:
1. The first line starts as usual and shows that we have an input variable `a`, corresponding to the `num` argument.
2. The second line is an XLA call that squares the input number.
3. The last line returns the result of the XLA call, denoted by `b`.

**The side effect isn't captured by the jaxpr.** A jaxpr is built by **`tracing`**: it records only the JAX operations applied to the traced values. The `append` to the global list is ordinary Python code, not a JAX operation, so it runs while the function is being traced but leaves no trace in the jaxpr and is not part of the compiled code. You may therefore see the side effect when the function is traced, but not on subsequent calls that reuse the compiled code.

**Note:** One more important thing to note is the `device` value in the jaxprs. The parameter is always present, but unless you specify a device during the jit transform, e.g. `jit(fn_name, device=...)`, it is listed as `None`. This can be confusing, because your computation may well be running on an accelerator even though no device name appears here. The reason is that a jaxpr is just an expression, independent of where it is going to run: it describes the computation handed to XLA, not the device on which that computation will be executed.

In [11]:
# Subsequent calls to the jitted function
for i, num in enumerate([4, 8, 16]):
    print("Iteration: ", i+1)
    print("Result: ", number_squared_jit(num))
    print("="*50)
    
# What's in the list?
print("\n Results in the global list")
squared_numbers

Iteration:  1
Result:  16
Iteration:  2
Result:  64
Iteration:  3
Result:  256

 Results in the global list


[Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)>,
 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>]

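You might be wondering why there are two traced values in the global list if the side effect only shows up when the function is traced. The reason is that the function was traced twice: once by `jax.make_jaxpr(...)` in the previous cell (the tracer with `level=1/1`) and once on the first real call inside the loop (the tracer with `level=0/1`). The second and third calls reused the compiled code, so nothing more was appended. Also notice that the list holds tracers, not the squared numbers, which is why relying on side effects like this inside a jitted function is unreliable.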
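A minimal sketch of the usual fix: keep the jitted function pure and do the bookkeeping outside it, on the concrete results it returns.

```python
import jax

squared_numbers = []


@jax.jit
def number_squared(num):
    # Pure: no global state is touched inside the jitted function
    return num ** 2


for num in [4, 8, 16]:
    # The side effect happens in ordinary Python, on a concrete value
    squared_numbers.append(int(number_squared(num)))

print(squared_numbers)  # [16, 64, 256]
```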

# How much to JIT?

Before diving into the nuances related to JIT, let's assume that you have two functions that can be jitted with no problems, for example, our `get_dot_product(...)` and `apply_activation(...)` functions. Should you jit them both individually, or should you combine them into one function or module and jit that function/module? Let's see that in action.

In [12]:
# Wrap the two functions in a single function
# so that we jit it once instead of jitting each of them
def forward_pass(W, X):
    Z = get_dot_product(W, X)
    A = apply_activation(Z)
    return Z, A



# Always use a seed
key = random.PRNGKey(1234)

# Never reuse a key: split it into fresh subkeys instead
key, w_key, x_key = random.split(key, 3)

# We will use a bigger array this time
W = random.normal(key=w_key, shape=[2000, 10000], dtype=jnp.float32)
X = random.normal(key=x_key, shape=[10000, 20000], dtype=jnp.float32)

# JIT the functions we have individually
dot_product_jit  = jit(get_dot_product)
activation_jit = jit(apply_activation)

# JIT the function that wraps both the functions
forward_pass_jit = jit(forward_pass)

for i in range(3):
    start = time.time()
    # Don't forget to use `block_until_ready(..)`
    # else you will be recording dispatch time only
    Z = dot_product_jit(W, X).block_until_ready()
    end = time.time()
    print(f"Iteration: {i+1}")
    print(f"Time taken to execute dot product: {end - start:.2f} seconds", end="")
    
    start = time.time()
    A = activation_jit(Z).block_until_ready()
    print(f", activation function: {time.time()- start:.2f} seconds")
    
    # Now measure the time with a single jitted function that calls
    # the other two functions (reset the timer first)
    start = time.time()
    Z, A = forward_pass_jit(W, X)
    Z, A = Z.block_until_ready(), A.block_until_ready()
    print(f"Time taken by the forward pass function: {time.time() - start:.2f} seconds")
    print("")
    print("="*50)

Iteration: 1
Time taken to execute dot product: 6.88 seconds, activation function: 0.09 seconds
Time taken by the forward pass function: 7.08 seconds

Iteration: 2
Time taken to execute dot product: 6.82 seconds, activation function: 0.07 seconds
Time taken by the forward pass function: 7.21 seconds

Iteration: 3
Time taken to execute dot product: 7.21 seconds, activation function: 0.06 seconds
Time taken by the forward pass function: 6.98 seconds



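Which approach should you follow? In the runs above, both take roughly the same time, because the dot product dominates the cost. Jitting the combined function hands XLA the whole computation at once, which gives it more room to optimize (for example, by fusing operations) and avoids dispatching two separate compiled calls. However, I don't have confirmation that the combined approach always works better; a Twitter user, who is a heavy JAX user, pointed this out.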
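One way to see what the combined version hands to XLA is to inspect its jaxpr, as we did earlier. The sketch below, using small arbitrary shapes, should show a single jitted call whose body contains both the dot product and the `max`; the exact printed text, including the name of the call primitive, varies with the JAX version.

```python
import jax
import jax.numpy as jnp


def get_dot_product(W, X):
    return jnp.dot(W, X)


def apply_activation(x):
    return jnp.maximum(0.0, x)


def forward_pass(W, X):
    Z = get_dot_product(W, X)
    A = apply_activation(Z)
    return Z, A


W = jnp.ones((4, 8), dtype=jnp.float32)
X = jnp.ones((8, 3), dtype=jnp.float32)

closed_jaxpr = jax.make_jaxpr(jax.jit(forward_pass))(W, X)
print(closed_jaxpr)
# The whole forward pass is a single equation: one jitted call
print(len(closed_jaxpr.jaxpr.eqns))
```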


# JIT and Python Control Flow

A natural question that comes to mind at this stage is: *Why don't we just JIT everything? That would give a massive gain in execution speed.* Though true in some sense, you can't jit everything: there are certain scenarios where jitting doesn't work out of the box. Let's take a few examples to understand this.

In [13]:
def square_or_cube(x):
    if x % 2 == 0:
        return x ** 2
    else:
        return x * x * x


# JIT transformation
square_or_cube_jit = jit(square_or_cube)

# Run the jitted version on some sample data
try:
    val = square_or_cube_jit(2)
except Exception as ex:
    print(type(ex).__name__, ex)

ConcretizationTypeError Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function square_or_cube at <ipython-input-13-d8fbb25d44c9>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


So why didn't this code work? Let's break down the whole process of JIT once again, including what happened here.

1. When we `jit` a function, we aim to get a compiled version of that function, so that we can cache and reuse the compiled code for different values of the arguments. 
2. To achieve this, JAX traces it on abstract values that represent sets of possible inputs
3. There are [different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py) used during tracing, and the kind of abstraction used when tracing a particular function depends on the transformation being applied.
4. By default, jit traces your code on the **`ShapedArray`** abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), jnp.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
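A quick way to watch tracing on abstract values without any concrete data is `jax.eval_shape`, which traces a function on `jax.ShapeDtypeStruct` inputs and returns only the output's shape and dtype. A minimal sketch with the activation function from earlier:

```python
import jax
import jax.numpy as jnp


def apply_activation(x):
    return jnp.maximum(0.0, x)


# An abstract input: shape and dtype only, no values
abstract_x = jax.ShapeDtypeStruct((3,), jnp.float32)

out = jax.eval_shape(apply_activation, abstract_x)
print(out.shape, out.dtype)  # (3,) float32
```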

Coming to why the above code failed: the value of `x` isn't concrete while tracing. As a result, when we hit a line like `if x % 2 == 0`, the expression `x % 2 == 0` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set {True, False}. **When Python attempts to coerce that to a concrete True or False, we get an error: we don't know which branch to take, and can't continue tracing!**
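A few common ways around this are sketched below. `jnp.where` computes both branches and selects elementwise; `jax.lax.cond` stages both branches into the compiled code and picks one at run time; and marking `x` as static with `static_argnums` (discussed in the Caching section) lets the original Python `if` run on a concrete value, at the cost of one compilation per distinct `x`. The suffixed function names are our own.

```python
import jax
import jax.numpy as jnp
from jax import lax


def square_or_cube(x):
    if x % 2 == 0:
        return x ** 2
    else:
        return x * x * x


@jax.jit
def square_or_cube_where(x):
    # Both branches are evaluated; jnp.where picks the result elementwise
    return jnp.where(x % 2 == 0, x ** 2, x * x * x)


@jax.jit
def square_or_cube_cond(x):
    # lax.cond traces both branches and selects one at run time
    return lax.cond(x % 2 == 0, lambda v: v ** 2, lambda v: v * v * v, x)


# x is static: the Python `if` sees a concrete value (recompiles per new x)
square_or_cube_static = jax.jit(square_or_cube, static_argnums=0)

for fn in (square_or_cube_where, square_or_cube_cond, square_or_cube_static):
    print(fn(2), fn(3))  # 4 27
```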

Let's take one more example, this time involving a loop.

In [14]:
def multiply_n_times(x, n):
    count = 0
    res = 1
    while count < n:
        res = res * x
        count += 1
    return res


try:
    val = jit(multiply_n_times)(2, 5)
except Exception as ex:
    print(type(ex).__name__, ex)

ConcretizationTypeError Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function multiply_n_times at <ipython-input-14-b9a83545696a>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


If the computation inside the loop is expensive, you can still jit part of the function body. Let's see it in action.

In [15]:
# Jitting the expensive computational part
def multiply(x, i):
    return x * i

# Mark `x` (argument 0) as static: JAX recompiles for each new value of `x`
multiply_jit = jit(multiply, static_argnums=0)

# Leave the loop itself un-jitted
def multiply_n_times(x, n):
    count = 0
    res = 1
    while count < n:
        res = multiply_jit(x, res)
        count += 1
    return res


%timeit multiply_n_times(2, 5)

24.5 µs ± 879 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
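Alternatively, the loop itself can be moved into the compiled code with `jax.lax.fori_loop`, which accepts a traced trip count under `jit`. A minimal sketch, starting the product from `jnp.ones_like(x)` so that the loop carry keeps the same dtype as `x`:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def multiply_n_times(x, n):
    # body_fun(i, carry) -> new carry; here the carry is the running product
    def body_fun(i, res):
        return res * x

    return lax.fori_loop(0, n, body_fun, jnp.ones_like(x))


print(multiply_n_times(2, 5))  # 32
```

Because `n` is traced here rather than static, calling the function with a different `n` of the same dtype should reuse the compiled code.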


# Caching

When you `jit` a function, it gets compiled on the first call. Subsequent calls with arguments of the same shapes and dtypes reuse the cached compiled code, so you pay the compilation price only once per input signature!
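A minimal sketch to observe this, using a Python counter as a deliberate side effect (as we saw in the jaxpr section, it only runs while JAX traces); `scale` and `trace_count` are illustrative names.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def scale(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces
    return 2 * x


scale_jit = jax.jit(scale)
scale_jit(jnp.ones(3))   # float32[3]: traced and compiled
scale_jit(jnp.ones(3))   # same shape/dtype: cache hit
scale_jit(jnp.zeros(3))  # different values, same shape/dtype: cache hit
scale_jit(jnp.ones(4))   # new shape: traced and compiled again
print(trace_count)  # 2
```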

If we need to JIT a function that has a condition on the value of an input, we can tell JAX to treat that input as static by specifying `static_argnums`: the argument is not traced, and its concrete Python value is used during tracing. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified input. It is only a good strategy if that argument is guaranteed to take a limited number of distinct values.
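A minimal sketch of `static_argnums`, again counting traces with an illustrative Python counter. Here `n` drives a plain Python loop, so it must be concrete; each new value of `n` triggers a new trace and compilation, while new values of `x` do not.

```python
import jax

trace_count = 0


def power(x, n):
    global trace_count
    trace_count += 1  # runs once per trace
    result = 1.0
    for _ in range(n):  # plain Python loop: n must be a concrete value
        result = result * x
    return result


power_jit = jax.jit(power, static_argnums=1)
results = [
    power_jit(2.0, 3),  # traced for n=3
    power_jit(5.0, 3),  # reuses the n=3 compilation
    power_jit(2.0, 4),  # new static value: traced again
]
print(trace_count)  # 2
```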


<div class="alert alert-warning">
    <strong>Warning: Don't do this!</strong>
</div>




```python
def multiply(x, i):
    return x * i

def multiply_n_times(x, n):
    count = 0
    res = 1
    while count < n:
        res = jit(multiply)(x, res)
        count += 1
    return res


print(multiply_n_times(2, 5))

```

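Doing that creates a new jit-transformed object on every iteration of the loop. With a named function like `multiply`, JAX can usually still find the cached compiled code, because its cache is keyed on the underlying function, but you still pay for building a new wrapper each time. If the wrapped function is a `lambda` or a `functools.partial` created inside the loop, every iteration produces a new function object, and JAX has to trace and compile it again. Create the jitted function once, outside the loop, and reuse it.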
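A sketch that counts traces to illustrate this caching behavior, which is also described in the JAX documentation: a `lambda` created inside the loop is a new function object on every iteration, so JAX cannot reuse the previous compilation, whereas a jitted function created once and reused is traced only once. The counter is an illustrative side effect.

```python
import jax

trace_count = 0


def multiply(x, i):
    global trace_count
    trace_count += 1  # runs only when JAX traces `multiply`
    return x * i


# Anti-pattern: a fresh lambda each iteration is a new function object
for _ in range(3):
    jax.jit(lambda x, i: multiply(x, i))(2, 3)
lambda_traces = trace_count
print(lambda_traces)  # 3

# Preferred: create the jitted function once, outside the loop, and reuse it
trace_count = 0
multiply_jit = jax.jit(multiply)
for _ in range(3):
    multiply_jit(2, 3)
reuse_traces = trace_count
print(reuse_traces)  # 1
```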


That's it for Part-7! More in the next tutorial!