In [18]:
import jax
import jax.numpy as jnp
import numpy as np

from jax import grad, jit, vmap, pmap

from jax import random
import matplotlib.pyplot as plt
from copy import deepcopy
from typing import Tuple, NamedTuple
import functools

## Parallelism in JAX<br>
#### Parallelism in JAX is handled by a function called pmap

In [19]:
# Get ourselves some TPU Goodness
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

jax.devices()

RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

In [20]:
devices = jax.devices()
print(f"Available devies: {devices}")

Available devies: [CpuDevice(id=0)]


In [21]:
# If above code does not show GPU, jax is not properly configured

# Check default backend
print(f"Default backend: {jax.default_backend()}")

Default backend: cpu


### Note: JAX does not support native GPU installations on Windows<br> 
#### It only supports GPU acceleration on Windows through the use of WSL2 (Windows Subsystem for Linux 2). 
#### For direct installation on Windows, only CPU support is officially available. 

#### (We can still learn JAX concepts on CPU or use Colab TPUs)

In [29]:
# Simple running example
x = np.arange(5)                # Signal
w = np.array([2., 3., 4.])      # Kernel/Window

# 1D Convolution
def convolve(w, x):
    op = []

    for i in range(1, len(x)-1):
        op.append(jnp.dot(x[i-1:i+2], w))

    return jnp.array(op)

- repr(): String Representation of Object
- Shows what the object is (type and identity) unambiguously

In [30]:
result = convolve(w, x)
print(repr(result))

Array([11., 20., 29.], dtype=float32)


In [31]:
n_devices = jax.local_device_count()
print(f"Number of available devices: {n_devices}")

Number of available devices: 1


Here you would see more than one (depending upon how many cores your GPU/TPU has)

In [32]:
# Imagine we have a heavier load (a batch of examples)
xs = np.arange(5*n_devices).reshape(-1, 5)
ws = np.stack([w]*n_devices)

print(xs.shape, ws.shape)

(1, 5) (1, 3)


#### First Way to Optimize: Simply use vmap!

In [34]:
vmap_result = jax.vmap(convolve)(ws, xs)
# By default, the in_axes argument in set to (0,0)
print(repr(vmap_result))

Array([[11., 20., 29.]], dtype=float32)


- vmap automatically vectorizes a function to work on batches of data, without us having to write explicit loops or manual batching code
- in_axes tells vmap with acis to map over for each argument
- Eg: jax.vmap(func, in_axes=(arg1_axis, arg2_axis, ...))(arg1, arg2, ...)
- 0 = map over axis 0 (first dim, which is usually the batch dimension)
- 1 = map over axis 1
- None: don't map, use same value for all (broadcasting)

In [36]:
# The amazing part is, if you just swap vmap for pmap, you are now 
# running on multiple devices

pmap_result = jax.pmap(convolve)(ws, xs)
print(repr(pmap_result))

Array([[11., 20., 29.]], dtype=float32)


- All of this happens independently
- There are no cross-device communication costs

In [37]:
# Same operation, but smarter way
# We dont have to manually broadcast w

sol = jax.pmap(convolve, in_axes=(None, 0))(w, xs)
print(repr(sol))

Array([[11., 20., 29.]], dtype=float32)


#### All this is great, but sometimes we do require communication between devices!<br><br>
Let's see how this is handled

In [40]:
# Same Convolution Example, but this time
# we communicate across devices to normalize the outputs

def normalized_convolve(w, x):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    output = jnp.array(output)
    # Same as before till here

    output = output/jax.lax.psum(output, axis_name='batch_dim')

    return output

jax.lax.psum() --> "Parallel Sum": Sums values across multiple devices (GPUs/TPUs)<br><br>
Takes in arguments:
- x: value to sum (array or scalar)
- axis_name: which parallel axis to sum across

In [44]:
res_pmap = jax.pmap(normalized_convolve, axis_name='batch_dim', in_axes=(None, 0))(w, xs)
res_vmap = jax.vmap(normalized_convolve, axis_name='batch_dim', in_axes=(None, 0))(w, xs)
# axis_name just gives an arbitrary name to mapped axis

In [45]:
print(repr(res_pmap))
print(repr(res_vmap))
print(f"Verify output is normalized: {sum(res_pmap[:,0])}")

Array([[1., 1., 1.]], dtype=float32)
Array([[1., 1., 1.]], dtype=float32)
Verify output is normalized: 1.0


#### Couple more useful functions

In [46]:
# Aside from grads, we also need to return losses

def sum_squared_error(x, y):
    return sum((x-y)**2)

x = jnp.arange(4, dtype=jnp.float32)
y = x + 0.1

In [48]:
# Efficient Way to return grad and loss value
print(x)
print(jax.value_and_grad(sum_squared_error)(x,y))

[0. 1. 2. 3.]
(Array(0.03999997, dtype=float32), Array([-0.2       , -0.20000005, -0.19999981, -0.19999981], dtype=float32))


In [51]:
# Sometimes, the loss function also needs to return intermediate results

def sum_squared_error_with_aux(x, y):
    return sum((x-y)**2), (x-y)

In [None]:
# This will throw Error - Gradient only works for scalar output functions
jax.grad(sum_squared_error_with_aux)(x,y) 

In [52]:

jax.grad(sum_squared_error_with_aux, has_aux=True)(x,y) 

(Array([-0.2       , -0.20000005, -0.19999981, -0.19999981], dtype=float32),
 Array([-0.1       , -0.10000002, -0.0999999 , -0.0999999 ], dtype=float32))