<a href="https://colab.research.google.com/github/AccelAI/Jax-Intro-Tutorial/blob/main/JaxIntroDemo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up

The first thing we need to do is make sure that we have all the packages that we need. If you are installing on your own machine, make sure you are running Python 3; then either install Anaconda for everything we need (and more), or install each of the following packages individually.

In [None]:
import sys
!{sys.executable} -m pip install numpy
!{sys.executable} -m pip install torch
!{sys.executable} -m pip install jax
!{sys.executable} -m pip install jaxlib
!{sys.executable} -m pip install tensorflow

Next, let's import our first JAX modules and do some basic math to compare JAX with NumPy.

In [3]:
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap, random, lax


In [4]:
# Force the CPU backend so JAX does not warn about a missing GPU/TPU.
import jax; jax.config.update('jax_platform_name', 'cpu')

## Basics

In [None]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

### Numpy vs Jax

In [None]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the default device (the CPU, given the config above)
print(x)
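The cell above only times the JAX side. As a rough sketch of the comparison with NumPy, the snippet below times the same matrix product in both libraries. The helper `best_time` and the smaller `size` of 500 are assumptions made for illustration, and the timings will vary by machine and backend.

```python
import time

import numpy as np
import jax.numpy as jnp
from jax import random

size = 500  # assumption: smaller than the 3000 used above so the sketch runs quickly
key = random.PRNGKey(0)
x_jax = random.normal(key, (size, size), dtype=jnp.float32)
x_np = np.asarray(x_jax)  # same values as a plain NumPy array


def best_time(fn, repeats=5):
    """Hypothetical helper: fastest wall-clock time over a few runs."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)


# Warm up JAX once so one-off compilation is not included in the timing.
jnp.dot(x_jax, x_jax.T).block_until_ready()

numpy_seconds = best_time(lambda: np.dot(x_np, x_np.T))
# block_until_ready() waits for JAX's asynchronous dispatch to finish.
jax_seconds = best_time(lambda: jnp.dot(x_jax, x_jax.T).block_until_ready())

print(f"NumPy: {numpy_seconds:.6f} s")
print(f"JAX:   {jax_seconds:.6f} s")

# Both libraries should agree up to float32 rounding.
results_match = np.allclose(
    np.dot(x_np, x_np.T),
    np.asarray(jnp.dot(x_jax, x_jax.T)),
    rtol=1e-3,
    atol=1e-2,
)
print(results_match)
```

Which library wins depends on the hardware: on a CPU the two are often close, while on a GPU or TPU JAX usually has the advantage for large matrices.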

### Basic Multiplication

In [36]:
a = jnp.array(3.)
b = jnp.array([3., 2., 1.])
c = jnp.array([5., 5., 5.])

def dot_mul(a, b, c):
  return (a * jnp.dot(b, c))

In [37]:
dot_mul(a, b, c)

DeviceArray(90., dtype=float32)

## Gradient

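Using `jax.grad` with our function `dot_mul`, we can get the gradient of the function's return value with respect to one of its parameters. By default this is the first argument (`a`); `argnums` selects other arguments.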

In [38]:
grad(dot_mul)(a, b, c)

DeviceArray(30., dtype=float32)

In [39]:
grad(dot_mul, argnums=[1])(a, b, c)

(DeviceArray([15., 15., 15.], dtype=float32),)

In [40]:
grad(dot_mul, argnums=[2])(a, b, c)

(DeviceArray([9., 6., 3.], dtype=float32),)
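The three calls above can also be combined. As a minimal sketch, passing a tuple to `argnums` returns one gradient per selected argument, and `jax.value_and_grad` returns the function value together with the gradient. The variable names on the left-hand sides are chosen here for illustration.

```python
import jax.numpy as jnp
from jax import grad, value_and_grad

a = jnp.array(3.)
b = jnp.array([3., 2., 1.])
c = jnp.array([5., 5., 5.])


def dot_mul(a, b, c):
    return a * jnp.dot(b, c)


# A tuple of argnums yields a tuple of gradients, one per argument.
grad_a, grad_b, grad_c = grad(dot_mul, argnums=(0, 1, 2))(a, b, c)

# value_and_grad evaluates the function and its gradient (w.r.t. a) together.
value, grad_a_again = value_and_grad(dot_mul)(a, b, c)

print(value)   # 90.0, same as dot_mul(a, b, c)
print(grad_a)  # 30.0, i.e. dot(b, c)
print(grad_b)  # [15. 15. 15.], i.e. a * c
print(grad_c)  # [9. 6. 3.], i.e. a * b
```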

## VMap
`vmap` handles batches of data, i.e. inputs that carry an extra batch dimension, such as a matrix whose rows are individual examples.

In [None]:
a2 = jnp.array([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.], [4., 4., 4.]])
a2

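`grad` is only defined for functions with a scalar output, so it does not support batched data directly: the call below raises an error because `dot_mul(a, a2, b)` returns a vector rather than a scalar. `vmap`, in contrast, is built to map a function over a batch dimension.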

In [None]:
grad(dot_mul)(a, a2, b)
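To see the failure without stopping the notebook, the sketch below wraps the same call in a `try`/`except`. It assumes the current JAX behavior of raising a `TypeError` for non-scalar outputs; the exact wording of the message may differ between versions.

```python
import jax.numpy as jnp
from jax import grad


def dot_mul(a, b, c):
    return a * jnp.dot(b, c)


a = jnp.array(3.)
b = jnp.array([3., 2., 1.])
a2 = jnp.array([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.], [4., 4., 4.]])

# With a (4, 3) matrix as the second argument the output is a vector.
output_shape = dot_mul(a, a2, b).shape
print(output_shape)

try:
    grad(dot_mul)(a, a2, b)
    error_message = None
except TypeError as err:
    # grad refuses to differentiate a function with non-scalar output.
    error_message = str(err)

print(error_message)
```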

In [None]:
vmap(dot_mul, in_axes=(None, None, 0))(a, b, a2)

To get gradients across a batch of data, combine the two:

In [None]:
vmap(grad(dot_mul), in_axes=(None, None, 0))(a, b, a2)
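One way to check what the combined transform computes is to compare it with a plain Python loop over the rows of `a2`. This is only an illustrative sketch; the names `batched` and `looped` are introduced here.

```python
import jax.numpy as jnp
from jax import grad, vmap


def dot_mul(a, b, c):
    return a * jnp.dot(b, c)


a = jnp.array(3.)
b = jnp.array([3., 2., 1.])
a2 = jnp.array([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.], [4., 4., 4.]])

# Gradient w.r.t. a for every row of a2, computed in one vectorized call.
batched = vmap(grad(dot_mul), in_axes=(None, None, 0))(a, b, a2)

# The same gradients computed one row at a time.
looped = jnp.stack([grad(dot_mul)(a, b, row) for row in a2])

print(batched)  # [ 6. 12. 18. 24.], i.e. dot(b, row) for each row
print(jnp.allclose(batched, looped))
```

The loop and the `vmap` version give the same values; `vmap` simply avoids the Python-level iteration.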

## Just in Time compilation (JIT)
Because Python is an interpreted language, it can be quite slow, especially when dealing with large data sets. To address this, JAX provides just-in-time compilation through `jit`, which plays a similar role to `tf.function` (AutoGraph) in TensorFlow and TorchScript in PyTorch.

In [None]:
jit(dot_mul)(a, b, c)

In [None]:
jax.jit(jax.vmap(jax.grad(dot_mul), in_axes=(None, None, 0)))(a, b, a2)

We can inspect what `jit` traces by printing the function's jaxpr, JAX's intermediate representation of the computation.

In [None]:
jax.jit(dot_mul) # returns a wrapped function that is compiled on its first call

In [None]:
jax.make_jaxpr(jax.jit(dot_mul))(a, b, c)
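As a rough sketch of the compile-once behavior, the snippet below times the first call of a jitted function, which includes tracing and compilation, against a second call that reuses the cached result. The name `fast_dot_mul` is introduced here for illustration, and for a function this small the absolute timings are not meaningful beyond showing that overhead.

```python
import time

import jax
import jax.numpy as jnp


def dot_mul(a, b, c):
    return a * jnp.dot(b, c)


a = jnp.array(3.)
b = jnp.array([3., 2., 1.])
c = jnp.array([5., 5., 5.])

fast_dot_mul = jax.jit(dot_mul)

start = time.perf_counter()
first = fast_dot_mul(a, b, c).block_until_ready()   # traces and compiles
first_call = time.perf_counter() - start

start = time.perf_counter()
second = fast_dot_mul(a, b, c).block_until_ready()  # reuses the compiled code
second_call = time.perf_counter() - start

print(f"first call:  {first_call:.6f} s")
print(f"second call: {second_call:.6f} s")
print(first, second)  # both equal dot_mul(a, b, c)
```

Calling the jitted function with inputs of a different shape or dtype triggers a new trace and compilation.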

## PMAP (Parallel Map)
`pmap` distributes a computation across multiple hardware devices (GPUs / TPUs). First, list the devices available locally:

In [5]:
jax.local_devices()

[<jaxlib.xla_extension.Device at 0x7fb20c4b16b0>]
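A minimal sketch of `jax.pmap` with the same `dot_mul` function follows. It assumes one slice of the mapped argument per local device, so with the single device listed above the leading axis has length 1. The array `c_per_device` is made up for illustration.

```python
import jax
import jax.numpy as jnp


def dot_mul(a, b, c):
    return a * jnp.dot(b, c)


n_devices = jax.local_device_count()

a = jnp.array(3.)
b = jnp.array([3., 2., 1.])

# One row per device: the mapped axis must not exceed the device count.
c_per_device = jnp.stack(
    [jnp.full((3,), float(i + 1)) for i in range(n_devices)]
)

# As with vmap, in_axes=None broadcasts an argument to every device,
# while 0 splits the argument along its leading axis.
out = jax.pmap(dot_mul, in_axes=(None, None, 0))(a, b, c_per_device)

print(out.shape)  # (n_devices,)
print(out)        # entry i is 3 * dot(b, [i + 1, i + 1, i + 1])
```

On a machine with several accelerators, each row of `c_per_device` is processed on its own device.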