In [12]:
from typing import NamedTuple
import functools
import jax
from jax import numpy as jnp
import numpy as np


class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray

def init(rng) -> Params:
    weights_key, bias_key = jax.random.split(rng)
    weight = jax.random.normal(weights_key, ())
    bias = jax.random.normal(bias_key, ())
    return Params(weight, bias)

def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    pred = params.weight * xs + params.bias
    return jnp.mean((pred - ys)**2)

lr = 5e-3

def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    new_params = jax.tree_map(lambda param, g: param - g * lr, params, grads)

    return new_params, loss



In [13]:
jax.local_device_count()

5

In [14]:
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise

params = init(jax.random.PRNGKey(0))
# n_devices = jax.local_device_count()
n_devices = 4

replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

In [15]:
def split(arr):
  """Splits the first axis of `arr` evenly across the number of devices."""
  return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)

type(x_split)

numpy.ndarray

In [17]:
def type_after_update(name, obj):
    print(f"After 1st 'update()', '{name}' is a {type(obj)}")

for i in range(1000):
    # replicated_params, loss = update(replicated_params, x_split, y_split)
    replicated_params, loss = jax.pmap(update, axis_name='num_devices')(replicated_params, x_split, y_split)
    
    if i == 0:
        type_after_update('replicated_params.weight: ', replicated_params.weight)
        type_after_update('loss: ', loss)
        type_after_update('x_split: ', x_split)
    if i % 100 == 0:
        print(f"Step {i:3d}, loss: {loss[0]:.3f}")

params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))



After 1st 'update()', 'replicated_params.weight: ' is a <class 'jaxlib.xla_extension.ArrayImpl'>
After 1st 'update()', 'loss: ' is a <class 'jaxlib.xla_extension.ArrayImpl'>
After 1st 'update()', 'x_split: ' is a <class 'numpy.ndarray'>
Step   0, loss: 0.212
Step 100, loss: 0.212
Step 200, loss: 0.212
Step 300, loss: 0.212
Step 400, loss: 0.212
Step 500, loss: 0.212
Step 600, loss: 0.212
Step 700, loss: 0.212
Step 800, loss: 0.212
Step 900, loss: 0.212
