### The magic sauce over numpy

Transforms are what Jax adds on top of numpy. This is a more advanced topic, so if you wish to skip it and come back later, go ahead. It's fine.

So, about transforms in Jax. The main ones are `jit`, `grad`, `vmap` and `pmap`.

This notebook is about jit.

### jit()

Jax calls itself accelerated numpy. Or a faster numpy. Or however you interpret it. jit is one of the key components in this acceleration. What jit does is compile your function the first time you call it and keep the compiled version in a cache, so later calls with the same input shapes and dtypes run faster than a regular function call. But how?



#### Translating computer programs (or rather, a CS101 refresher)

Computers only understand 1's and 0's (in other words, binary numbers). So when you try to run your code, it needs to be translated into binary numbers first. This process isn't straightforward. There are multiple levels of what CS junkies call abstractions. One possible abstraction hierarchy can be like this (or what it used to look like before [llvm](https://llvm.org/) and [clang](https://clang.llvm.org/) came). 


```
-------> Your code 
        -------> **Translator**
                -------> Assembler (based on the instruction Set that your CPU / GPU maker made, also known as ISA) 
                        -------> Binary (executable)
```



This translator part in the middle can be either a compiler or an interpreter. A compiler takes your entire code file and translates it at once. There's a whole field of study in CS on compiler design if you want to look into that abyss. Anyway, since a compiler translates (compiles) the whole file at once, it can apply various optimizations (memory, speed) and enforce checks at compile time so that your code ends up with fewer errors. Compiled programs are also, by comparison, faster. (Video games and operating systems are prime examples.) **AND you need to compile only once!** (terms and conditions apply)

Python, on the other hand, is interpreted. The interpreter translates one statement at a time and **you have to interpret every time** you want to run your code. The process is slow and you can't enforce the same checks and optimizations as a compiler can. Everything is only known when the code is encountered, so errors can't be caught beforehand. So when Jax tries to cache your code for faster recall, the regular Python design holds it back. Instead, Jax traces your Python function into its own intermediate representation and hands that to a compiler called XLA, which compiles it into optimized code that Jax caches.


And oh, JIT means just in time compilation (thank me later :) )

So let's check how much of a performance benefit jit adds on top of regular code!

In [1]:
import jax as J
from jax import jit, random
import jax.numpy as jnp


In [2]:
key = random.key(123)

x = random.normal(key, shape=(9, ))
x


Array([-0.10502207, -0.56205004, -0.56485987, -1.7063935 , -1.3647023 ,
       -0.42215332,  1.0077653 ,  0.9922631 , -0.61236995], dtype=float32)

In [3]:
def f(x):
    return jnp.sin(x) + jnp.cos(x)

%timeit f(x)


53.3 µs ± 79.2 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


With jit?

In [4]:
f_with_jit = jit(f)

%timeit f_with_jit(x)


16.3 µs ± 36.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


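See that wall time difference? That's roughly 3x faster (53.3 µs vs 16.3 µs), and this is only a 9-element array! One caveat: Jax dispatches work asynchronously, so for a rigorous benchmark you'd call `.block_until_ready()` on the result before stopping the clock.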
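To see where the caching kicks in, here is a minimal sketch that times the first and second call separately. It assumes the same `f` and `x` as above and uses `time.perf_counter` instead of `%timeit` so it runs outside a notebook; the variable names `first_call` and `second_call` are just for illustration.

```python
import time

import jax.numpy as jnp
from jax import jit, random


def f(x):
    return jnp.sin(x) + jnp.cos(x)


x = random.normal(random.key(123), shape=(9,))
f_with_jit = jit(f)

# First call: Jax traces f, compiles it with XLA, then runs it.
start = time.perf_counter()
f_with_jit(x).block_until_ready()  # wait for the async computation to finish
first_call = time.perf_counter() - start

# Second call with the same shape and dtype: the cached executable is reused.
start = time.perf_counter()
f_with_jit(x).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call (trace + compile + run): {first_call * 1e3:.3f} ms")
print(f"second call (cached):               {second_call * 1e3:.3f} ms")
```

The first number should be noticeably larger than the second, but the exact values depend on your hardware and Jax version.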

#### Alternate syntax

In [5]:
@jit
def f(x):
    return jnp.sin(x) + jnp.cos(x)

%timeit f(x)


16.4 µs ± 40.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### But it's not all sunshine....

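Well, caching and speeding up is nice, but the cached version is only valid for the input shapes and dtypes it was compiled for. Call the function with a different shape or dtype and jit has to trace and compile it again. This adds some overhead, and repeated compilations can make code execution slow, for example when the shapes keep changing inside a loop. Also, not everything in Python is supported in Jax transformations such as jit. Your code should be side-effect free, and Python control flow needs care: a `print` statement only runs while the function is being traced, an `if-else` conditional on the value of a traced argument raises an error, and a Python `for` loop gets unrolled into the compiled program. Jax enforces a very functional programming approach. As a result you may not be able to jit every part of your code.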

This is where Jax is different from Tensorflow or Pytorch. You have to think about which parts of your code are going to be covered by transformations, which parts only take care of computations and which parts take care of loops and conditions.
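As a small illustration of the conditional problem, the sketch below jits the same toy function twice: once with a Python `if` on the input value, and once with `jnp.where`. The function names `relu_python_if` and `relu_where` are made up for this example.

```python
import jax
import jax.numpy as jnp
from jax import jit


@jit
def relu_python_if(x):
    # Under jit, x is a tracer without a concrete value, so Python
    # cannot decide which branch to take.
    if x > 0:
        return x
    return jnp.zeros_like(x)


@jit
def relu_where(x):
    # jnp.where selects elementwise inside the traced program,
    # so no Python-level decision is needed.
    return jnp.where(x > 0, x, 0.0)


try:
    relu_python_if(jnp.array(1.5))
except jax.errors.ConcretizationTypeError as err:
    print("Python `if` on a traced value failed:", type(err).__name__)

print(relu_where(jnp.array([-1.0, 0.5, 2.0])))
```

For branches that are expensive or loops that should not be unrolled, `jax.lax.cond` and `jax.lax.fori_loop` are the structured alternatives worth looking at.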

#### What does my Python code become before XLA compiles it?

In [6]:
from jax import make_jaxpr


In [7]:
make_jaxpr(f)(x)


{ lambda ; a:f32[9]. let
    b:f32[9] = pjit[
      jaxpr={ lambda ; c:f32[9]. let
          d:f32[9] = sin c
          e:f32[9] = cos c
          f:f32[9] = add d e
        in (f,) }
      name=f
    ] a
  in (b,) }

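This jaxpr is the intermediate representation that Jax builds by tracing your function. When the function is flagged with jit, Jax hands it to XLA, which compiles it, and later calls run the compiled version from the cache.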
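If you want to peek one step further down, a jitted function can be lowered and compiled by hand. The sketch below assumes the same `f` and `x` as above; the text printed by `as_text()` depends on your Jax version, so no output is shown here.

```python
import jax.numpy as jnp
from jax import jit, random


def f(x):
    return jnp.sin(x) + jnp.cos(x)


x = random.normal(random.key(123), shape=(9,))

# Trace f for this input shape/dtype and lower it to the form handed to XLA.
lowered = jit(f).lower(x)
print(lowered.as_text())

# Compile explicitly; the result is a callable specialized to f32[9] inputs.
compiled = lowered.compile()
print(compiled(x))
```

Normally jit does both steps for you on the first call; doing them by hand is mostly useful for inspection and debugging.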

#### Now what if you do include some side effects or code unwanted by jax, for example a print statement?

In [8]:
@jit
def side_effect_func(x, y):
    print(x)
    x = x + y  # rebinding a local name is fine, it is not a side effect
    return x

make_jaxpr(side_effect_func)(x, x)


Traced<ShapedArray(float32[9])>with<DynamicJaxprTrace(level=2/0)>


{ lambda ; a:f32[9] b:f32[9]. let
    c:f32[9] = pjit[
      jaxpr={ lambda ; d:f32[9] e:f32[9]. let f:f32[9] = add d e in (f,) }
      name=side_effect_func
    ] a b
  in (c,) }

In [9]:
@jit
def without_side_effect(x, y):
    return x + y

make_jaxpr(without_side_effect)(x, x)


{ lambda ; a:f32[9] b:f32[9]. let
    c:f32[9] = pjit[
      jaxpr={ lambda ; d:f32[9] e:f32[9]. let f:f32[9] = add d e in (f,) }
      name=without_side_effect
    ] a b
  in (c,) }

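See that `Traced<ShapedArray(float32[9])>...` line printed above the jaxpr of the function with the print statement? That's the `print` running while Jax traces the function. At that point `x` is not a real array but a tracer, a placeholder that only carries a shape and a dtype. Also notice that the inner jaxprs of the two functions are identical: the `print` never made it into the traced program, so when the compiled version runs from the cache, nothing gets printed at all.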

This may not ring alarm bells as long as you're getting correct results, but side effects that silently disappear like this may result in inconsistent behavior across calls and devices (which Jax wants to avoid at all costs).

You can read more here about [jax insisting on pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).

TL;DR : Just know that Jax is functional at core and only likes pure functions in your code. Anything else will probably run but jax can't provide you any guarantee about correct or consistent results.
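Here is a minimal sketch of that behavior. The list `trace_log` is a hypothetical helper that records every time Python actually executes the function body, and `jax.debug.print` is Jax's own way to print real values from inside a jitted function.

```python
import jax
import jax.numpy as jnp
from jax import jit

trace_log = []  # hypothetical helper: one entry per trace of the function body


@jit
def add(x, y):
    trace_log.append(x.shape)             # Python side effect: only runs while tracing
    jax.debug.print("runtime x = {}", x)  # staged into the program: runs on every call
    return x + y


a = jnp.ones(3)
add(a, a)  # first call: traced and compiled
add(a, a)  # same shape and dtype: served from the cache, body not re-run
print("traces so far:", len(trace_log))

add(jnp.ones(4), jnp.ones(4))  # new shape: traced and compiled again
print("traces so far:", len(trace_log))
```

The `append` happens once per compilation rather than once per call, which is exactly why relying on side effects inside jitted code gives surprising results.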

## What if I want some parameters of my function to not be compiled?

You may encounter this scenario if your function has to build an array whose shape is not known in advance (it is supplied as a parameter). Consider the following function:

In [10]:
def multinomial_sample(rng_key, logits, n_samples):
    return random.categorical(rng_key, logits, shape=[n_samples, ])

key, subkey = random.split(key)
multinomial_sample(subkey, x, 50)


Array([6, 7, 7, 6, 0, 1, 6, 2, 6, 7, 6, 6, 6, 0, 7, 1, 1, 0, 2, 6, 0, 7,
       8, 6, 6, 6, 7, 3, 6, 6, 1, 3, 7, 7, 4, 8, 6, 7, 8, 1, 7, 6, 6, 7,
       7, 2, 2, 6, 6, 6], dtype=int32)

jit'ing this function as-is fails. Under jit, `n_samples` becomes a traced value with no concrete number attached, but XLA needs to know the shape of every array at compile time, and here the shape of the returned array depends on `n_samples`.

In [11]:
jit(multinomial_sample)(subkey, x, 50)


TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function multinomial_sample at /tmp/ipykernel_5821/897393480.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n_samples.

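The solution here is to tell jit to treat `n_samples` as a compile-time constant instead of tracing it. For that you have to use `static_argnums` (here `2`, the zero-based position of `n_samples`). The catch is that jit compiles a separate version of the function for every distinct value of a static argument.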

You can read more at: https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html


In [12]:
jit(multinomial_sample, static_argnums=2)(subkey, x, 50)


Array([6, 7, 7, 6, 0, 1, 6, 2, 6, 7, 6, 6, 6, 0, 7, 1, 1, 0, 2, 6, 0, 7,
       8, 6, 6, 6, 7, 3, 6, 6, 1, 3, 7, 7, 4, 8, 6, 7, 8, 1, 7, 6, 6, 7,
       7, 2, 2, 6, 6, 6], dtype=int32)

In [13]:
jit(multinomial_sample, static_argnums=2)(subkey, x, 150)


Array([7, 7, 7, 7, 7, 0, 7, 7, 6, 6, 8, 6, 4, 2, 7, 6, 6, 7, 7, 7, 7, 3,
       6, 1, 7, 2, 7, 3, 6, 7, 6, 7, 6, 8, 7, 7, 0, 6, 0, 5, 7, 7, 7, 0,
       6, 7, 6, 0, 2, 5, 6, 6, 5, 7, 8, 7, 0, 7, 4, 2, 0, 6, 3, 6, 0, 7,
       2, 0, 7, 0, 2, 5, 2, 0, 7, 6, 8, 2, 7, 7, 7, 2, 7, 7, 7, 6, 0, 6,
       5, 6, 7, 0, 8, 5, 6, 6, 2, 6, 2, 0, 3, 7, 6, 8, 0, 6, 7, 8, 0, 7,
       4, 6, 6, 6, 7, 6, 7, 5, 6, 6, 7, 5, 7, 6, 3, 7, 7, 6, 7, 6, 7, 7,
       7, 0, 6, 4, 7, 7, 6, 6, 0, 7, 0, 6, 6, 7, 7, 6, 6, 7], dtype=int32)
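The same thing can be written with the decorator syntax and `static_argnames`, which refers to the parameter by name instead of position. The sketch below is one way to do it; `compile_log` is a hypothetical helper that records each time the function is traced, to show that every new value of `n_samples` triggers a fresh compilation.

```python
from functools import partial

from jax import jit, random

compile_log = []  # hypothetical helper: one entry per trace


@partial(jit, static_argnames="n_samples")
def multinomial_sample(rng_key, logits, n_samples):
    # n_samples is a plain Python int here, so it can be used in a shape.
    compile_log.append(n_samples)
    return random.categorical(rng_key, logits, shape=(n_samples,))


key = random.key(123)
key, logits_key, sample_key = random.split(key, 3)
logits = random.normal(logits_key, shape=(9,))

print(multinomial_sample(sample_key, logits, 50).shape)
print(multinomial_sample(sample_key, logits, 50).shape)   # cached, no new trace
print(multinomial_sample(sample_key, logits, n_samples=150).shape)
print("values compiled for:", compile_log)
```

So static arguments work best for values that change rarely, such as sizes or configuration flags. Passing a different value on every call would bring back the recompilation cost on each one.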