**S03P02_tutorial_parallel_evaluation_in_jax.ipynb**

Arz

2024 APR 21 (SUN)

reference:
https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html

In [2]:
import numpy as np

In [3]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [4]:
%xmode minimal

Exception reporting mode: Minimal


this notebook covers the facilities built into JAX for SPMD (single-program, multiple-data) code.

# TPU setup

In [5]:
jax.devices()

[cuda(id=0)]

# the basics

## ex) convolve

In [6]:
def convolve(x, y):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], y))
    return jnp.array(output)

In [7]:
x = jnp.arange(5)
y = jnp.array([2., 3., 4.])

In [8]:
convolve(x, y)

Array([11., 20., 29.], dtype=float32)
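the loop above slides the 3-element window y over x and takes a dot product at each position, which amounts to a "valid"-mode cross-correlation. as a sanity check (a sketch, not part of the referenced tutorial), the same numbers can be reproduced with jnp.correlate. x is created as floats here, an assumption made only to keep the dtypes aligned.

```python
import jax.numpy as jnp

def convolve(x, y):
    # same sliding-window loop as in the cell above
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], y))
    return jnp.array(output)

x = jnp.arange(5.)             # float input (assumption: floats, to match y's dtype)
y = jnp.array([2., 3., 4.])

loop_result = convolve(x, y)
library_result = jnp.correlate(x, y, mode="valid")  # built-in valid-mode cross-correlation

print(loop_result)
print(library_result)
```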

In [19]:
num_devices = jax.local_device_count()

x_batch = np.arange(5*num_devices).reshape(-1, 5)
y_batch = np.stack([y]*num_devices)

if num_devices > 1:
    y_batch[1] = [7, 4, 2]

print(x_batch)
print(y_batch)

[[0 1 2 3 4]]
[[2. 3. 4.]]


### vmap

In [20]:
jax.vmap(convolve)(x_batch, y_batch)

Array([[11., 20., 29.]], dtype=float32)

### pmap

In [21]:
jax.pmap(convolve)(x_batch, y_batch)

Array([[11., 20., 29.]], dtype=float32)

❓ the output of jax.pmap is an array whose slices live on the devices that computed them. if we were to run another parallel computation on it, the elements would stay on their respective devices, without incurring cross-device communication costs.

In [22]:
jax.pmap(convolve)(x_batch, jax.pmap(convolve)(x_batch, y_batch))

Array([[ 78., 138., 198.]], dtype=float32)
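a small sketch comparing the two transformations on the same batch: jax.vmap vectorizes on one device, while jax.pmap places one slice of the leading axis on each device, so that axis cannot be larger than jax.local_device_count(). the batch is therefore sized from the device count; on a single-device machine both calls see a batch of one.

```python
import jax
import jax.numpy as jnp
import numpy as np

def convolve(x, y):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], y))
    return jnp.array(output)

num_devices = jax.local_device_count()

# one row of x and one row of y per device
x_batch = np.arange(5 * num_devices, dtype=np.float32).reshape(-1, 5)
y_batch = np.stack([np.array([2., 3., 4.], dtype=np.float32)] * num_devices)

vmap_result = jax.vmap(convolve)(x_batch, y_batch)   # vectorized on a single device
pmap_result = jax.pmap(convolve)(x_batch, y_batch)   # one row per device

print(np.allclose(vmap_result, pmap_result))
```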

# specifying in_axes

⚠️ note: unlike vmap's case, the in_axes option must be a tuple, not a list.

e.g.) in_axes=[1, 0]: throws error, in_axes=(1, 0): OK.

In [24]:
jax.pmap(convolve, in_axes=1, out_axes=1)(x_batch.T, y_batch.T) 

Array([[11.],
       [20.],
       [29.]], dtype=float32)

## case: dimension inconsistency

In [30]:
# jax.pmap(convolve, in_axes=1, out_axes=0)(x_batch.T, y_batch)
# forbidden: throws an error because of a dimension inconsistency
# (in_axes=1 applies to both arguments, but y_batch is batched along axis 0)

### fix

in_axes=(axis to take for arg1 (x_batch), axis to take for arg2 (y_batch))

In [37]:
jax.pmap(convolve, in_axes=(1, 0), out_axes=0)(x_batch.T, y_batch)

Array([[11., 20., 29.]], dtype=float32)

In [38]:
jax.pmap(convolve, in_axes=(1, 0), out_axes=1)(x_batch.T, y_batch)

Array([[11.],
       [20.],
       [29.]], dtype=float32)

## case: when only one of the arguments is batched

In [23]:
jax.pmap(convolve, in_axes=(0, None))(x_batch, y)

Array([[11., 20., 29.]], dtype=float32)

# pmap and jit

jax.pmap JIT-compiles the function given to it as part of its operation.

**so there is no need to additionally jax.jit it**.

# communication between devices

sometimes we need to pass information between the devices.

## ex) normalized convolve

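❓ to normalize across the batch, every device needs the sum of the outputs from all devices. a *collective operation* provides that.

the collective operation here is **jax.lax.psum**, which sums its argument over the mapped axis identified by axis_name.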

In [39]:
def normalized_convolve(x, y):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], y))
    output = jnp.array(output)
    return output/jax.lax.psum(output, axis_name='p')

In [41]:
result = jax.pmap(normalized_convolve, axis_name='p')(x_batch, y_batch)

print(result)

[[1. 1. 1.]]


sum of each column should equal 1. (with a single device, psum returns that device's own output, so every entry is trivially 1.)

In [42]:
jnp.sum(result, axis=0)

Array([1., 1., 1.], dtype=float32)

## vmap case

jax.vmap can also have axis_name.

In [43]:
result = jax.vmap(normalized_convolve, axis_name='p')(x_batch, y_batch)

print(result)

[[1. 1. 1.]]


In [44]:
jnp.sum(result, axis=0)

Array([1., 1., 1.], dtype=float32)
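with one device the normalized result is all ones, which hides what psum does. since jax.vmap accepts axis_name too, a batch of two rows can stand in for two devices (an assumption made here so the sketch runs on any machine). the second kernel [7, 4, 2] is borrowed from the multi-device branch earlier in the notebook.

```python
import jax
import jax.numpy as jnp

def normalized_convolve(x, y):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], y))
    output = jnp.array(output)
    # psum adds up `output` across every slice of the axis named 'p'
    return output / jax.lax.psum(output, axis_name='p')

x_rows = jnp.arange(10.).reshape(2, 5)
y_rows = jnp.array([[2., 3., 4.],
                    [7., 4., 2.]])

result = jax.vmap(normalized_convolve, axis_name='p')(x_rows, y_rows)
column_sums = jnp.sum(result, axis=0)

print(result)        # entries are now fractions rather than all ones
print(column_sums)   # each column sums to 1 (up to float rounding)
```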

# nesting jax.pmap and jax.vmap

the reason we specify axis_name as a string is so we can use collective operations when nesting jax.pmap and jax.vmap.

in general, jax.pmap and jax.vmap can be nested in any order, and with themselves (so you can have a pmap within another pmap, for instance).
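a minimal sketch of such nesting, with hypothetical axis names "devices" and "rows": jax.vmap is applied inside jax.pmap, and each jax.lax.psum picks the axis it reduces over by name. the leading dimension is sized from jax.local_device_count() so the example also runs on one device.

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name="devices")   # outer axis: one slice per device
@functools.partial(jax.vmap, axis_name="rows")      # inner axis: vectorized on each device
def normalize(value):
    per_device_total = jax.lax.psum(value, axis_name="rows")            # sum within a device
    global_total = jax.lax.psum(per_device_total, axis_name="devices")  # sum across devices
    return value / global_total

num_devices = jax.local_device_count()
values = (jnp.arange(4 * num_devices, dtype=jnp.float32) + 1.0).reshape(num_devices, 4)

normalized = normalize(values)
print(normalized.shape)
print(jnp.sum(normalized))   # should be 1 up to float rounding
```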

## ex) regression training loop with data parallelism

each batch is split into sub-batches which are evaluated on separate devices.

In [45]:
from typing import NamedTuple
import functools

In [50]:
class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray

def initialize(key) -> Params:
    weight_key, bias_key = jax.random.split(key)
    weight = jax.random.normal(weight_key, ())
    bias = jax.random.normal(bias_key, ())
    return Params(weight, bias)

def loss_function(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    y_pred = params.weight*x + params.bias
    return jnp.mean((y - y_pred)**2)

learning_rate = 0.005

# data-parallel update step: pmap runs one copy per device along the axis named "num_devices"
@functools.partial(jax.pmap, axis_name="num_devices")
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> tuple[Params, jnp.ndarray]:
    # compute loss and gradients on each given minibatch (individually on each device using pmap)
    losses, gradss = jax.value_and_grad(loss_function)(params, x, y)

    # take the mean of losses and gradients 
    loss = jax.lax.pmean(losses, axis_name="num_devices")
    grads = jax.lax.pmean(gradss, axis_name="num_devices")

    # update params
    new_params = jax.tree_util.tree_map(lambda param, g: param - g*learning_rate, params, grads)

    return new_params, loss
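why averaging with pmean is a sensible reduction here: when every device gets a sub-batch of the same size, the mean of the per-device gradients equals the gradient of the mean loss over the whole batch. the sketch below checks this numerically. it uses a plain (weight, bias) tuple instead of Params and broadcasts params with in_axes=None, both simplifications made here for brevity.

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np

def loss_function(params, x, y):
    weight, bias = params
    return jnp.mean((y - (weight * x + bias)) ** 2)

@functools.partial(jax.pmap, axis_name="devices", in_axes=(None, 0, 0))
def averaged_grads(params, x, y):
    grads = jax.grad(loss_function)(params, x, y)       # gradient on this device's sub-batch
    return jax.lax.pmean(grads, axis_name="devices")    # average over all devices

num_devices = jax.local_device_count()
rng = np.random.default_rng(0)
x = rng.normal(size=(num_devices * 16,)).astype(np.float32)
y = 3 * x - 1

params = (jnp.float32(0.5), jnp.float32(0.0))
device_grads = averaged_grads(params, x.reshape(num_devices, 16), y.reshape(num_devices, 16))
full_grads = jax.grad(loss_function)(params, x, y)      # gradient on the undivided batch

# every device holds the same averaged gradient, so compare the first copy
print(device_grads[0][0], full_grads[0])
print(device_grads[1][0], full_grads[1])
```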

### data & setting

In [73]:
num_devices = jax.local_device_count()
if num_devices > 1:
    x = np.random.normal(size=(num_devices*16,))
    noise = np.random.normal(scale=0.1, size=(num_devices*16,))
else:
    x = np.random.normal(size=(100,))
    noise = np.random.normal(scale=0.1, size=(100,))

y = 3*x - 1 + noise

**initialize parameters and replicate across devices.**

In [74]:
params = initialize(jax.random.key(123))
num_devices = jax.local_device_count()
replicated_params = jax.tree_util.tree_map(lambda x: jnp.array([x]*num_devices), params)

# params will be communicated to each device when update() is first called,
# then each copy of params will stay on its own device subsequently.
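a sketch of the replicate / un-replicate round trip, written with jax.tree_util.tree_map and a plain dict in place of Params (an assumption for brevity). stacking adds a leading axis of size num_devices, which is the axis jax.pmap splits across devices; taking index 0 recovers a single copy.

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()
params = {"weight": jnp.float32(0.5), "bias": jnp.float32(-1.0)}

# add a leading device axis: every leaf goes from shape () to (num_devices,)
replicated = jax.tree_util.tree_map(lambda leaf: jnp.stack([leaf] * num_devices), params)

# all copies are identical, so the first one is enough to recover the original
single = jax.tree_util.tree_map(lambda leaf: leaf[0], replicated)

print(replicated["weight"].shape, replicated["bias"].shape)
print(single)
```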

In [75]:
print(replicated_params)
print(type(replicated_params))
print(type(replicated_params.weight), type(replicated_params.bias))

Params(weight=Array([-0.12120728], dtype=float32), bias=Array([-1.7093881], dtype=float32))
<class '__main__.Params'>
<class 'jaxlib.xla_extension.ArrayImpl'> <class 'jaxlib.xla_extension.ArrayImpl'>


**split data into num_devices sub-batches, one per device.**

In [76]:
def split_data(x):
    return x.reshape(num_devices, x.shape[0]//num_devices, *x.shape[1:])
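the reshape above only works when the number of samples is divisible by num_devices. the sketch below is a hypothetical variant of split_data that takes num_devices as an argument and checks divisibility up front; it uses NumPy only, so the shapes can be inspected without any devices.

```python
import numpy as np

def split_data(x, num_devices):
    # hypothetical variant: num_devices is passed explicitly and divisibility is checked
    if x.shape[0] % num_devices != 0:
        raise ValueError(
            f"{x.shape[0]} samples cannot be split evenly across {num_devices} devices"
        )
    return x.reshape(num_devices, x.shape[0] // num_devices, *x.shape[1:])

samples = np.arange(24).reshape(12, 2)   # 12 samples with 2 features each
shards = split_data(samples, num_devices=4)

print(shards.shape)   # (4, 3, 2): 4 devices, 3 samples per device, 2 features
```

the first sub-batch holds the first 3 samples, the second the next 3, and so on, because reshape keeps the samples in their original order.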

In [77]:
x_batch = split_data(x)
y_batch = split_data(y)

In [78]:
print(x_batch)
print(type(x_batch))

# x_batch is a NumPy array, so it stays in host (CPU) memory.
# each call of update() communicates the data samples from the host to the devices.

[[ 0.33596082  0.4160086   0.149929    0.75965705  2.14434221  2.14092551
   0.69953455  0.78254533  1.12916677 -0.78112501  1.00605159  0.66764108
  -1.20902483  1.07137935  0.39122226  0.49429518 -1.87083096 -1.64257915
  -0.23276109 -0.63230313 -0.33630381 -0.88661858  0.60621643 -0.02192579
   1.04646006  1.56940571  0.61744109 -0.47373537  0.36797314 -1.01429225
   1.56518501  0.88358108  0.19960978  0.49541447  0.02417953 -0.49567561
   0.25201306 -0.82571614 -0.11546714  0.40030643 -0.89492143  0.79662802
   1.30199593  0.64503735  0.11328555  0.10429732  1.54464023 -0.72126082
   0.11065123 -0.43703535 -0.22680218 -1.03178947 -0.37999731 -0.11700991
   1.54878456 -1.07871082 -1.94356133 -0.16686865 -0.8442943  -1.13331754
   0.24904953  1.73928313 -0.31184078  0.21499956 -1.60165727  2.02374289
  -1.21639568 -1.26595997  0.57824847 -2.50447147 -0.35102177 -0.93901348
  -1.02771758  1.20862917  0.23053017  0.04920859 -0.93937376 -1.9007063
   0.84431626  0.39518575  0.84667979 -

### training

In [79]:
def print_type(name: str, obj: object):
    print(f"{name} has type {type(obj)}")

In [80]:
for i in range(1000):
    # update
    # - this is where the params and data get communicated to devices
    replicated_params, loss = update(replicated_params, x_batch, y_batch)

    if i == 0:
        print("after the first call of update():")
        print_type("replicated_params.weight", replicated_params.weight)
        print_type("loss", loss)
        print_type("x_batch", x_batch)

    if i%100 == 0:
        print(f"epoch {i:3d}, loss: {loss[0]:.3f}")
        print(loss)

params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], replicated_params))

after the first call of update():
replicated_params.weight has type <class 'jaxlib.xla_extension.ArrayImpl'>
loss has type <class 'jaxlib.xla_extension.ArrayImpl'>
x_batch has type <class 'numpy.ndarray'>
epoch   0, loss: 9.980
[9.980379]
epoch 100, loss: 1.384
[1.38366]
epoch 200, loss: 0.198
[0.19841388]
epoch 300, loss: 0.035
[0.03498678]
epoch 400, loss: 0.012
[0.01245054]
epoch 500, loss: 0.009
[0.00934265]
epoch 600, loss: 0.009
[0.008914]
epoch 700, loss: 0.009
[0.00885487]
epoch 800, loss: 0.009
[0.0088467]
epoch 900, loss: 0.009
[0.00884558]


In [81]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

In [82]:
fig = px.scatter(x=x, y=y)
fig_model = px.line(x=x, y=params.weight*x + params.bias)
fig_model.data[0].line.color = "#e02a19"
fig.add_trace(fig_model.data[0])

fig.show()

# aside: hosts & devices in JAX

**host**: CPU that manages several devices. 

- A single host can only manage so many devices (usually 8)

so when running very large parallel programs, multiple hosts are needed, and some finesse is required to manage them.

**--xla_force_host_platform_device_count=8**

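❓ the flag has to be set before importing JAX, and it only changes the number of CPU (host platform) devices. in the cell below JAX has already been imported and initialized on the GPU, so jax.devices() is unchanged.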

In [84]:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
jax.devices()

[cuda(id=0)]

this is especially useful for debugging and testing locally or even for prototyping in Colab since a CPU runtime is faster to (re-)start.
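a sketch of the intended order of operations, assuming a fresh Python process (for example right after a kernel restart) in which JAX has not been imported yet. the CPU backend is requested explicitly, since on a machine with a GPU the default jax.devices() lists the GPU.

```python
import os

# must run before `import jax` (and before any other module imports it)
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

cpu_devices = jax.devices("cpu")   # ask for the CPU backend explicitly
print(len(cpu_devices))
print(cpu_devices)
```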