## Colab TPU Setup


If you’re running this code in Google Colab, be sure to select TPU as the hardware accelerator (Runtime → Change runtime type) before running the cells below.

In [None]:
# Importing the submodule also binds the top-level name `jax` used below.
# NOTE(review): `jax.tools.colab_tpu.setup_tpu()` is never called here — confirm
# whether the installed JAX version still needs it.
import jax.tools.colab_tpu
# Show the devices JAX can see in this runtime.
jax.devices()

## The basics

In [None]:
import functools
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Toy 1-D input of 5 values and a 3-element kernel, matching the 3-wide
# window that `convolve` slides over its input.
x = np.arange(5)
w = np.array([2., 3., 4.])

In [None]:
def convolve(x, w):
  """Slides a three-element window across `x` and dots each window with `w`.

  The first and last positions of `x` are skipped, so entry `k` of the result
  is `dot(x[k:k+3], w)` and the output has `len(x) - 2` entries.
  """
  windows = (x[centre - 1:centre + 2] for centre in range(1, len(x) - 1))
  return jnp.array([jnp.dot(window, w) for window in windows])

In [None]:
# With x = [0, 1, 2, 3, 4] and w = [2, 3, 4] this evaluates to [11., 20., 29.].
convolve(x,w)

Now, let’s convert our convolve function into one that runs on entire batches of data.

In [None]:
n_devices = jax.local_device_count()

In [None]:
# One row of 5 inputs per device, and one copy of the kernel `w` per device.
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)

In [None]:
print(xs)

In [None]:
print(ws)

In [None]:
jax.vmap(convolve)(xs, ws)

To spread out the computation across multiple devices, just replace jax.vmap with jax.pmap:

In [None]:
jax.pmap(convolve)(xs,ws)

In [None]:
# Feed the result of one pmapped call straight in as the per-device kernel of
# a second one. Each row of the inner result has 3 entries (rows of `xs` have
# length 5), which matches the 3-wide window used by `convolve`.
jax.pmap(convolve)(xs, jax.pmap(convolve)(xs,ws))

## Specifying in_axes

In [None]:
# `in_axes=(0, None)`: map over axis 0 of `xs`, while `None` leaves `w`
# unmapped so the same single kernel is used for every row.
jax.pmap(convolve, in_axes=(0, None))(xs, w)

## pmap and jit

jax.pmap JIT-compiles the function given to it as part of its operation, so there is no need to additionally jax.jit it.

## Communication between devices

In [None]:
def normalized_convolution(x, w):
  """Convolves `x` with `w`, then divides by the sum over the mapped axis 'p'.

  The convolution uses a three-element window over the interior of `x`. The
  result is divided elementwise by `jax.lax.psum` of itself over the axis
  named 'p', so this must be called under a map that defines that axis name.
  """
  taps = [jnp.dot(x[centre - 1:centre + 2], w) for centre in range(1, len(x) - 1)]
  conv = jnp.array(taps)
  total = jax.lax.psum(conv, axis_name='p')
  return conv / total

In [None]:
# `axis_name='p'` names the mapped axis so the `psum(..., axis_name='p')`
# inside `normalized_convolution` can reduce over it.
jax.pmap(normalized_convolution, axis_name='p')(xs, ws)

In [None]:
# The same function under `vmap`: here 'p' names the batch axis being mapped
# over, so the `psum` sums across the rows of the batch.
jax.vmap(normalized_convolution, axis_name='p')(xs, ws)

## Nesting jax.pmap and jax.vmap

In [None]:
class Params(NamedTuple):
  """Parameters of the linear model `weight * x + bias` evaluated in `loss_fn`."""
  weight: jnp.ndarray  # Slope; `init` draws it as a scalar (shape `()`).
  bias: jnp.ndarray  # Intercept; `init` draws it as a scalar (shape `()`).

In [None]:
def init(rng) -> Params:
  """Draws a fresh scalar weight and bias from a standard normal.

  Args:
    rng: JAX PRNG key; it is split in two so the weight and the bias are
      sampled from independent keys.

  Returns:
    A `Params` tuple holding the sampled weight and bias.
  """
  key_w, key_b = jax.random.split(rng)
  return Params(
      weight=jax.random.normal(key_w, ()),
      bias=jax.random.normal(key_b, ()),
  )

In [None]:
def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
  """Returns the mean squared error of the linear model on `xs` against `ys`.

  The prediction is `params.weight * xs + params.bias`; the loss is the mean
  of the squared differences between that prediction and `ys`.
  """
  residual = (params.weight * xs + params.bias) - ys
  return jnp.mean(residual ** 2)

In [None]:
# Step size that `update` multiplies the averaged gradients by.
LEARNING_RATE = 0.005

In [None]:
# So far, the code is identical to the single-device case. Here's what's new:

# Remember that the `axis_name` is just an arbitrary string label used
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
  """Performs one SGD update step on params using the given data.

  The function is wrapped in `jax.pmap`, so every argument is mapped over its
  leading axis and the body below runs once per device on that device's slice.

  Args:
    params: Model parameters for this device.
    xs: Inputs of this device's minibatch.
    ys: Targets of this device's minibatch.

  Returns:
    A `(new_params, loss)` pair: the parameters after one gradient step and the
    loss averaged over the 'num_devices' axis.
  """

  # Compute the gradients on the given minibatch (individually on each device).
  loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)

  # Combine the gradient across all devices (by taking their mean).
  grads = jax.lax.pmean(grads, axis_name='num_devices')

  # Also combine the loss. Unnecessary for the update, but useful for logging.
  loss = jax.lax.pmean(loss, axis_name='num_devices')

  # Each device performs its own update, but since we start with the same params
  # and synchronise gradients, the params stay in sync.
  # `jax.tree_util.tree_map` is used instead of the top-level `jax.tree_map`
  # alias, which is deprecated and no longer present in recent JAX releases.
  new_params = jax.tree_util.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grads)

  return new_params, loss

In [None]:
# Generate true data from y = w*x + b + noise
# 128 samples of shape (128, 1); the noise is standard normal scaled by 0.5.
# The draws are unseeded, so the data differs from run to run.
# NOTE(review): `split` reshapes the first axis into (n_devices, 128 // n_devices),
# so 128 is assumed to be divisible by the number of devices — confirm.
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise

In [None]:
# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
# Stack `n_devices` copies of every leaf along a new leading axis so that the
# pmapped `update` hands one copy to each device.
# `jax.tree_util.tree_map` replaces the top-level `jax.tree_map` alias, which
# is deprecated and no longer present in recent JAX releases.
replicated_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * n_devices), params)

In [None]:
type(replicated_params.weight)

In [None]:
def split(arr):
  """Reshapes `arr` so its first axis becomes `(n_devices, rows_per_device)`.

  `rows_per_device` is `arr.shape[0] // n_devices`; all trailing axes are kept
  unchanged. `n_devices` is read from the enclosing module scope.
  """
  rows_per_device = arr.shape[0] // n_devices
  target_shape = (n_devices, rows_per_device) + tuple(arr.shape[1:])
  return arr.reshape(target_shape)

# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)

# `xs` is a NumPy array and `split` only reshapes it, so this prints the
# NumPy array type.
print(type(x_split))

In [None]:
def type_after_update(name, obj):
  """Prints the type of `obj`, labelled with `name`, for the first-step report."""
  obj_type = type(obj)
  print(f"after first `update()`, `{name}` is a", obj_type)

# Actual training loop.
for i in range(1000):

  # This is where the params and data get communicated to devices:
  replicated_params, loss = update(replicated_params, x_split, y_split)

  # The returned `replicated_params` and `loss` are now device-backed JAX arrays
  # (older JAX releases report them as `ShardedDeviceArray`; the exact type name
  # printed depends on the installed version), indicating that they're on the
  # devices.
  # `x_split`, of course, remains a NumPy array on the host.
  if i == 0:
    type_after_update('replicated_params.weight', replicated_params.weight)
    type_after_update('loss', loss)
    type_after_update('x_split', x_split)

  if i % 100 == 0:
    # Note that loss is actually an array of shape [num_devices], with identical
    # entries, because each device returns its copy of the loss.
    # So, we take the first element to print it.
    print(f"Step {i:3d}, loss: {loss[0]:.3f}")


# Plot results.

# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
# `jax.tree_util.tree_map` replaces the top-level `jax.tree_map` alias, which
# is deprecated and no longer present in recent JAX releases.
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], replicated_params))

In [None]:
# Scatter the noisy samples and overlay the line given by the trained
# `params.weight` and `params.bias`.
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend()
plt.show()

## Aside: hosts and devices in JAX

In [None]:
import os

# Ask XLA to expose 8 virtual CPU devices so multi-device code can be tried
# without accelerators.
# NOTE: XLA_FLAGS is only picked up when JAX initialises its backend, so this
# has no effect in a session where JAX has already been used. Run it in a fresh
# process/runtime, before anything else touches JAX (safest: before importing it).
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

# Import after the flag is set so the snippet also works on its own.
import jax

jax.devices()