### The magic sauce over numpy

Transforms are what Jax adds on top of NumPy. They are a somewhat advanced topic, so if you wish to skip this section and come back later, that's fine.

Jax provides several transforms: `jit`, `grad`, `pmap` and `vmap`.

This notebook is about jit.

### jit()

Jax calls itself accelerated NumPy, or a faster NumPy, however you interpret it. jit is one of the key components of this acceleration. It compiles your function the first time you call it and keeps the compiled version in a cache, so later calls run faster than a regular function call. But how?

Okay. You asked for it. 


#### Translating computer programs (or a rather CS101 refresher)

Computers only understand 1's and 0's (in other words, binary numbers). So when you try to run your code, it needs to be translated into binary first. This process isn't straightforward. There are multiple levels of what CS junkies call abstractions. One possible hierarchy looks like this (or at least, what it used to look like before [llvm](https://llvm.org/) and [clang](https://clang.llvm.org/) came along).


```
-------> Your code
        -------> Translator
                -------> Assembly (based on the instruction set architecture, or ISA, defined by your CPU / GPU maker)
                        -------> Binary (executable)
```



This translator in the middle can be either a compiler or an interpreter. A compiler takes your entire code file and translates it at once. (There's a whole field of study in CS on compiler design if you want to look into that abyss.) Because a compiler sees the whole program before it runs, it can apply various optimizations (memory, speed) and enforce checks at compile time, so more errors are caught before the code ever runs. Compiled programs are also, by comparison, faster (video games and operating systems are prime examples). **AND you need to compile only once!** (terms and conditions apply)

Python, on the other hand, is interpreted. Roughly speaking, the interpreter translates one line at a time, and **you have to interpret the code every time** you run it. This is slower, and an interpreter can't apply the same whole-program checks and optimizations as a compiler: many errors only show up when the offending line is actually reached. So when Jax wants to cache a fast version of your code, plain Python holds it back. Instead, Jax traces your Python function into its own intermediate representation and hands that to a compiler called XLA (Accelerated Linear Algebra), which compiles it into optimized code for your CPU, GPU or TPU. The compiled result is then cached.
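If you want to see the XLA step explicitly, here is a minimal sketch. It assumes a reasonably recent JAX (0.4 or later), where a jitted function exposes ahead-of-time `lower()` and `compile()` stages. `lower()` traces the function for the given input shape and dtype and produces the program XLA will receive; `compile()` asks XLA to turn it into an executable for the current backend. The exact text printed differs between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Same toy function used throughout this notebook
    return jnp.sin(x) + jnp.cos(x)

x = jnp.arange(9.0)

# Stage 1: trace f for a float32 array of shape (9,) and lower it to XLA's input language
lowered = jax.jit(f).lower(x)
hlo_text = lowered.as_text()
print(hlo_text[:300])  # peek at the beginning of the lowered program

# Stage 2: let XLA compile the lowered program for the current device (CPU/GPU/TPU)
compiled = lowered.compile()

# The compiled executable accepts arguments with the same shape and dtype
print(compiled(x))
```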


And oh, JIT stands for just-in-time compilation: the function is compiled when it is first called, not ahead of time (thank me later :) )
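To see the "just in time" part in action, here is a minimal sketch that times the first call of a jitted function (which includes tracing and XLA compilation) against a second call that reuses the cached executable. The array size and the helper `time_call` are arbitrary choices for this illustration, and the exact numbers depend on your hardware.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) + jnp.cos(x)

x = jnp.arange(1_000_000, dtype=jnp.float32)

def time_call(fn, arg):
    """Wall-clock seconds for one call, including waiting for the device result."""
    start = time.perf_counter()
    fn(arg).block_until_ready()  # JAX runs work asynchronously, so wait for it
    return time.perf_counter() - start

first = time_call(f, x)   # pays for tracing + XLA compilation
second = time_call(f, x)  # hits the cache, only runs the compiled code
print(f"first call: {first * 1e3:.2f} ms, second call: {second * 1e3:.2f} ms")
```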

So let's check how much of a performance benefit jit adds on top of regular code!

```python
import jax as J
from jax import jit, random
import jax.numpy as jnp
```

```python
key = random.PRNGKey(123)

x = random.normal(key, shape=(9,))
x
```

```
DeviceArray([-0.10502207, -0.56205004, -0.56485987, -1.7063935 ,
             -1.3647023 , -0.42215332,  1.0077653 ,  0.9922631 ,
             -0.61236995], dtype=float32)
```

```python
def f(x):
    return jnp.sin(x) + jnp.cos(x)

%timeit f(x)
```

```
131 µs ± 20.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```


With jit?

```python
f_with_jit = jit(f)

%timeit f_with_jit(x)
```

```
55.8 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```


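See that difference? The jitted version takes less than half the time per call (about 56 µs versus 131 µs). On an array this tiny, much of the gain likely comes from running `sin`, `cos` and `add` as one compiled XLA computation instead of dispatching three separate operations from Python.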
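One caveat when reading these numbers: JAX dispatches work asynchronously, so timing a call without waiting for its result can partly measure dispatch rather than computation. A fairer comparison, sketched below with the standard `timeit` module instead of the notebook's `%timeit` magic, makes one warm-up call so compilation is excluded, then waits on every result with `block_until_ready()`. The loop count `n` is an arbitrary choice.

```python
import timeit
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + jnp.cos(x)

f_with_jit = jax.jit(f)

key = jax.random.PRNGKey(123)
x = jax.random.normal(key, shape=(9,))

# Warm-up call: compiles f_with_jit so compile time stays out of the measurement
f_with_jit(x).block_until_ready()

n = 1000
# Average seconds per call, each call waiting until its result is ready
t_plain = timeit.timeit(lambda: f(x).block_until_ready(), number=n) / n
t_jit = timeit.timeit(lambda: f_with_jit(x).block_until_ready(), number=n) / n
print(f"plain: {t_plain * 1e6:.1f} µs/call, jit: {t_jit * 1e6:.1f} µs/call")
```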

#### Alternate syntax

Using `jit` as a decorator does the same thing:

```python
@jit
def f(x):
    return jnp.sin(x) + jnp.cos(x)

%timeit f(x)
```

```
39.6 µs ± 6.28 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```


### But it's not all sunshine...

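Well, caching and speeding up is nice, but the cache only helps when jit can reuse it. jit compiles a separate version of your function for each combination of input shapes and dtypes it sees, so calling it with a new shape or dtype (or redefining the function) triggers a fresh trace and compilation. Compilation adds overhead, and repeated compilations, for example inside a loop whose input shapes keep changing, can make your code slow.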
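Here is a minimal sketch of when that recompilation happens. It relies on the fact that Python side effects inside a jitted function (here, incrementing a hypothetical global counter `trace_count`) run only while JAX traces the function, so the counter effectively counts how many times jit had to compile. Such side effects are only useful as a probe like this; don't rely on them in real code.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only for this demo

@jax.jit
def f(x):
    global trace_count
    # Runs only during tracing, i.e. when jit has to (re)compile;
    # calls served from the cache skip this line entirely
    trace_count += 1
    return jnp.sin(x) + jnp.cos(x)

f(jnp.ones(9))    # first call with shape (9,): trace + compile
f(jnp.zeros(9))   # same shape and dtype: cache hit, no retrace
print(trace_count)  # 1

f(jnp.ones(10))   # new shape (10,): traced and compiled again
print(trace_count)  # 2

# A loop whose input shape changes every iteration recompiles every time
for n in range(3, 6):
    f(jnp.ones(n))
print(trace_count)  # 5
```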

### What does my Python code become under jit?

```python
from jax import make_jaxpr
```

```python
make_jaxpr(f)(x)
```

```
{ lambda ; a:f32[9]. let
    b:f32[9] = xla_call[
      call_jaxpr={ lambda ; c:f32[9]. let
          d:f32[9] = sin c
          e:f32[9] = cos c
          f:f32[9] = add d e
        in (f,) }
      name=f
    ] a
  in (b,) }
```

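This is a *jaxpr*, Jax's intermediate representation of your function, obtained by tracing `f` with an abstract input (here, a `float32` array of shape `(9,)`). It is not XLA's output yet. The outer `xla_call` appears because `f` was decorated with `@jit`; inside it, `call_jaxpr` holds the body of `f`: a `sin`, a `cos` and an `add`. When a function is flagged with jit, this is the program Jax hands to XLA to compile and cache, instead of re-running your Python code on every call.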
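For comparison, here is a minimal sketch of the same inspection on an un-jitted copy of the function (named `g` purely for illustration). Without `@jit` there should be no outer wrapper equation, just the primitives themselves. The code also lists those primitives programmatically through the returned `ClosedJaxpr` object's `.jaxpr.eqns`. The printed format, including the name of the jit wrapper in the jitted case, varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Same body as f, but without the @jit decorator
    return jnp.sin(x) + jnp.cos(x)

x = jnp.ones(9)
closed_jaxpr = jax.make_jaxpr(g)(x)
print(closed_jaxpr)

# Each equation records one primitive that JAX saw while tracing g
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```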