
In [7]:
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 16039940570321362052, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 16550877450029741743
 physical_device_desc: "device: XLA_CPU device"]

In [4]:
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 16145914453162090499, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 17979539285112938045
 physical_device_desc: "device: XLA_CPU device", name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 13771497194017127324
 physical_device_desc: "device: XLA_GPU device", name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 1334706176
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 3535907538658256442
 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"]

In [0]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random



In [0]:
# Generate key which is used to generate random numbers
key = random.PRNGKey(1)
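
JAX handles randomness through explicit keys: the same key always yields the same samples, so independent draws need a fresh subkey from `random.split`. A minimal sketch of that behavior (the variable names `a`, `b`, `c`, and `subkey` are illustrative and not part of the notebook):

```python
import jax.numpy as np
from jax import random

key = random.PRNGKey(1)

# Reusing a key reproduces exactly the same samples
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Splitting yields a new key for the next draw plus a subkey to consume now
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

print(a, b, c)
```

Splitting before each draw keeps results reproducible while avoiding accidentally correlated samples.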


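Even a simple matrix multiplication can be sped up quite a bit. Note that JAX dispatches work asynchronously, so a fair timing has to call `block_until_ready()` to wait for the result.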

In [10]:
# Generate a random matrix
x = random.uniform(key, (1000, 1000))
# Compare running times of 3 different matrix multiplications
%time y = onp.dot(x, x)                     # NumPy on the host
%time y = np.dot(x, x)                      # JAX; returns before the result is ready
%time y = np.dot(x, x).block_until_ready()  # JAX; waits for the result

CPU times: user 80.8 ms, sys: 15 ms, total: 95.7 ms
Wall time: 62.2 ms
CPU times: user 14.4 ms, sys: 8.1 ms, total: 22.5 ms
Wall time: 11.3 ms
CPU times: user 183 ms, sys: 86.1 ms, total: 269 ms
Wall time: 147 ms
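
Single `%time` runs are noisy, and only the `block_until_ready()` variant measures the full multiplication. A rough sketch of a fairer comparison follows; `time_fn` is a hypothetical helper written for this example, not a JAX function, and the numbers it prints depend entirely on the hardware:

```python
import time
import numpy as onp
import jax.numpy as np
from jax import random

def time_fn(fn, *args, repeats=5):
    """Hypothetical helper: best wall time in seconds of fn(*args) over several runs."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        # JAX arrays are computed asynchronously; wait for them. NumPy arrays lack this method.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best

key = random.PRNGKey(1)
x = random.uniform(key, (1000, 1000))
x_host = onp.asarray(x)  # copy to host memory once so the NumPy timing excludes the transfer

print("NumPy:", time_fn(onp.dot, x_host, x_host))
print("JAX:  ", time_fn(np.dot, x, x))
```

Taking the best of several runs also keeps one-off warm-up costs out of the comparison.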


## jit, grad, vmap

##### jit - speed up
##### grad - compute gradients
##### vmap - easy for batching

In [11]:
def ReLU(x):
  return np.maximum(0,x)

jit_ReLU = jit(ReLU)

%time out = ReLU(x).block_until_ready()

# The first call to the jitted version triggers compilation, so this timing includes compile time
%time jit_ReLU(x).block_until_ready()

%time out = jit_ReLU(x).block_until_ready()

CPU times: user 73.9 ms, sys: 1.84 ms, total: 75.7 ms
Wall time: 75.9 ms
CPU times: user 22.6 ms, sys: 1.88 ms, total: 24.5 ms
Wall time: 23.9 ms
CPU times: user 2.08 ms, sys: 11 µs, total: 2.09 ms
Wall time: 1.26 ms
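
The gap between the second and third timings comes from compilation: `jit` traces the Python function once per input shape and dtype, then reuses the compiled result. A small sketch that makes the tracing visible through a Python side effect (`trace_log` and `relu_traced` are names invented for this illustration):

```python
import jax.numpy as np
from jax import jit

trace_log = []  # records the input shape each time JAX traces the function

def relu_traced(x):
    # Python side effects run only while tracing, not when the compiled code executes
    trace_log.append(x.shape)
    return np.maximum(0, x)

jit_relu = jit(relu_traced)

a = np.ones((1000, 1000))
jit_relu(a).block_until_ready()  # first call: trace and compile
jit_relu(a).block_until_ready()  # same shape and dtype: reuses the compiled version

b = np.ones((10,))
jit_relu(b).block_until_ready()  # new shape: traced and compiled again

print(trace_log)
```

This is also why a jitted function should be warmed up once before it is timed.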


The next tool in the JAX kit is grad, which is inherited from the Autograd package.

In [12]:
def FiniteDiffGrad(x):
  return np.array((ReLU(x+1e-3)-ReLU(x-1e-3))/(2*1e-3))

# Compare the JAX gradient with a central finite-difference approximation
print("Jax Grad: ", jit(grad(jit(ReLU)))(2.))
print("FD Gradient:", FiniteDiffGrad(2.))

Jax Grad:  1.0
FD Gradient: 0.99998707
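
`grad` applies to any function that returns a scalar, and `value_and_grad` (imported above) returns the function value together with the gradient in one pass. A minimal sketch, assuming an illustrative `loss` function that does not appear in the notebook:

```python
import jax.numpy as np
from jax import grad, value_and_grad

def relu(x):
    return np.maximum(0, x)

def loss(w, x):
    """Illustrative scalar loss (an assumption for this example): sum of ReLU(w * x)."""
    return np.sum(relu(w * x))

w = np.array([1.0, -2.0, 3.0])
x = np.array([0.5, 1.0, 2.0])

# By default, differentiation is with respect to the first argument (w)
value, dw = value_and_grad(loss)(w, x)

# argnums selects a different argument, here x
dx = grad(loss, argnums=1)(w, x)

print(value, dw, dx)
```

The gradient entries are zero wherever `w * x` is negative, because ReLU is flat there.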


vmap - easy for batching

Let’s say you have a 100-dimensional feature vector and want to process it with a linear layer of 512 hidden units followed by the ReLU activation, and you want to compute the layer activations for a batch of size 32.

In [0]:
batch_dim = 32
feature_dim = 100
hidden_dim = 512

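# Split the key so that inputs, weights, and biases are drawn independently
x_key, w_key, b_key = random.split(key, 3)

# Generate a batch of vectors to process
X = random.normal(x_key, (batch_dim, feature_dim))

# Generate Gaussian weights and biases
params = [random.normal(w_key, (hidden_dim, feature_dim)),
          random.normal(b_key, (hidden_dim, ))]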

def relu_layer(params, x):
    """ Simple ReLU layer for a single sample """
    return ReLU(np.dot(params[0], x) + params[1])

def batch_version_relu_layer(params, x):
    """ Error-prone batch version: the transpose has to be handled by hand """
    return ReLU(np.dot(x, params[0].T) + params[1])

def vmap_relu_layer(params, x):
    """ vmap version of the ReLU layer """
    # Build the batched function, then apply it to the arguments
    return jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))(params, x)

out = np.stack([relu_layer(params, X[i, :]) for i in range(X.shape[0])])
out = batch_version_relu_layer(params, X)
out = vmap_relu_layer(params, X)

We have stacked the vectors into a matrix, so our input has dimensions (batch_dim, feature_dim). We therefore tell vmap that the batch dimension of the input is axis 0, while None marks the parameters as shared across all samples. out_axes then specifies how to stack the individual samples' outputs. To keep things consistent, we keep the first dimension as the batch dimension.
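
To see what `in_axes` and `out_axes` control, the same layer can be mapped over a batch stored along the last axis instead. A minimal sketch with freshly generated inputs (the key value and the names `batch_first` and `batch_last` are choices made for this example):

```python
import jax.numpy as np
from jax import random, vmap

key = random.PRNGKey(0)
w_key, b_key, x_key = random.split(key, 3)

W = random.normal(w_key, (512, 100))  # (hidden_dim, feature_dim)
b = random.normal(b_key, (512,))
X = random.normal(x_key, (32, 100))   # (batch_dim, feature_dim)
params = (W, b)

def relu_layer(params, x):
    """ReLU layer for a single sample x of shape (feature_dim,)."""
    W, b = params
    return np.maximum(0, np.dot(W, x) + b)

# Batch along axis 0 of the input and axis 0 of the output
batch_first = vmap(relu_layer, in_axes=(None, 0), out_axes=0)(params, X)

# Batch along axis 1 of the input (X.T) and axis 1 of the output
batch_last = vmap(relu_layer, in_axes=(None, 1), out_axes=1)(params, X.T)

print(batch_first.shape)  # (32, 512)
print(batch_last.shape)   # (512, 32)
```

Both calls compute the same activations; only the position of the batch axis in the input and output changes.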

#### References
https://roberttlange.github.io/posts/2020/03/blog-post-10/