In [None]:
# Install JAX. Run ONLY ONE of the install lines below, matching your hardware.
# `%pip` (rather than bare `pip` or `!pip`) installs into the running kernel's environment.
# Restart the kernel after installing so the new package is picked up.
# TODO: pin a version (e.g. "jax[cpu]==X.Y.Z") for reproducibility.
%pip install --upgrade "jax[cpu]"
# or, for an NVIDIA GPU with CUDA 12, use this line instead of the one above:
# %pip install --upgrade "jax[cuda12]"

### 1. Basic example: array arithmetic

In [16]:
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
# The scalar 2 is broadcast and added to every element of the array.
shifted = x + 2
print(shifted)


[3. 4. 5.]


### 2. Vectorization over Batches

In [17]:
import jax.numpy as jnp
from jax import vmap


def f(x):
    """Return x squared plus one, written for a single (scalar) input."""
    return x ** 2 + 1


# vmap turns the single-input function into one that maps over the
# leading axis of its argument.
batched_f = vmap(f)

x = jnp.arange(5)
print(batched_f(x))  # [1, 2, 5, 10, 17]


[ 1  2  5 10 17]


### 3. Auto-differentiation

Automatic differentiation (autodiff) is the ability of a system to compute the derivatives (gradients) of functions automatically. Because JAX derives gradients for you, it removes the error-prone step of working out and hand-coding derivative equations.
For example:
```python
def f(x):
    return x**2 + 3*x + 1
```

Its derivative is $f'(x) = 2x + 3$.


In [18]:
import time
import numpy as np  # NOTE(review): np is not used in the visible cells


def df(x):
    """Hand-derived derivative of f(x) = x**2 + 3*x + 1, i.e. f'(x) = 2x + 3."""
    return 2*x + 3  # Manually derived


# Time only the call itself, not the function definition or the print.
# perf_counter is a monotonic, high-resolution clock intended for measuring
# short intervals; time.time() is wall-clock time and can jump.
start_time = time.perf_counter()
result = df(5)
elapsed = time.perf_counter() - start_time
print(result)
print("--- %s seconds ---" % elapsed)


13
--- 0.00015306472778320312 seconds ---


In [19]:

import time

from jax import grad


def f(x):
    """Return x**2 + 3*x; its derivative is f'(x) = 2x + 3."""
    return x**2 + 3*x


# grad(f) returns a new function that evaluates df/dx at a given point.
# NOTE: this rebinds `f` and `df`, which earlier cells also define.
df = grad(f)

start_time = time.perf_counter()
# The argument is a float (5.0) because jax.grad rejects integer inputs.
# JAX dispatches work asynchronously; block_until_ready() makes the timer wait
# for the actual value. This first call also includes tracing overhead, so it
# is not a steady-state measurement.
result = df(5.0).block_until_ready()
elapsed = time.perf_counter() - start_time
print(result)
print("--- %s seconds ---" % elapsed)


13.0
--- 0.004549741744995117 seconds ---


### 4. `jit`: compiles once, then runs very fast

In [20]:
## without jit
import time


def normal_square(x):
    """Evaluate x**2 + 3*x in plain Python (no JAX tracing or compilation)."""
    return x**2 + 3*x


# Time only the call; perf_counter is the clock meant for short intervals.
# NOTE: the jit cell below times grad(square), i.e. the derivative, so the two
# timings are not a like-for-like comparison.
start = time.perf_counter()
value = normal_square(3.0)
elapsed = time.perf_counter() - start
print(value)
print("Normal call:", elapsed)


18.0
Normal call: 0.00012493133544921875


In [21]:
from jax import jit, grad
import jax.numpy as jnp
import time


@jit
def square(x):
    """Return x**2 + 3*x, JIT-compiled on its first call."""
    return x**2 + 3*x


# Derivative of square, i.e. 2x + 3, wrapped in jit as well.
grad_square = jit(grad(square))

# The first call traces and compiles; the second call with the same input
# shape/dtype reuses the cached compiled function and should be much faster.
for label in ("First call:", "Second call:"):
    start = time.perf_counter()
    # block_until_ready() waits for JAX's asynchronous dispatch to finish, so
    # the timer covers the computation itself rather than only its launch.
    result = grad_square(3.0).block_until_ready()
    elapsed = time.perf_counter() - start
    print(result)
    print(label, elapsed)


9.0
First call: 0.020296335220336914
9.0
Second call: 0.0002601146697998047
