<a href="https://colab.research.google.com/github/HarounH/smol/blob/main/rl/distributed_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags
import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

print(jax.devices())

a = jnp.arange(16)
print("Array", a)
print("Device", a.device)
print("Sharding", a.sharding)

mesh = Mesh(np.array(jax.devices()), ("i",))
print(mesh)

sharding = NamedSharding(
    mesh,
    P("i"),
)

a_sharded = jax.device_put(a, sharding)
print("Sharded array", a_sharded)
print("Device", a_sharded.devices())
print("Sharding", a_sharded.sharding)
jax.debug.visualize_array_sharding(a_sharded)


[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Array [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
Device TFRT_CPU_0
Sharding SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
Mesh('i': 8, axis_types=(Auto,))
Sharded array [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
Device {CpuDevice(id=0), CpuDevice(id=7), CpuDevice(id=5), CpuDevice(id=3), CpuDevice(id=1), CpuDevice(id=4), CpuDevice(id=6), CpuDevice(id=2)}
Sharding NamedSharding(mesh=Mesh('i': 8, axis_types=(Auto,)), spec=PartitionSpec('i',), memory_kind=device)


In [3]:
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ("i", "j"))
print(mesh)
batch_size = 192
input_dim = 64
output_dim = 128
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim))
w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim))
b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,))
x_sharded = jax.device_put(x, NamedSharding(mesh, P("i", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "j")))
b_sharded = jax.device_put(b, NamedSharding(mesh, P("j")))
out = jnp.dot(x_sharded, w_sharded) + b_sharded
print("Output shape", out.shape)
jax.debug.visualize_array_sharding(out)


Mesh('i': 4, 'j': 2, axis_types=(Auto, Auto))
Output shape (192, 128)


In [4]:
def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array:
    print("Local x shape", x.shape)
    print("Local w shape", w.shape)
    print("Local b shape", b.shape)
    return jnp.dot(x, w) + b
matmul_sharded = shard_map(
    matmul_fn, mesh, in_specs=(P("i", None), P(None, "j"), P("j")), out_specs=P("i", "j")
)
y = matmul_sharded(x_sharded, w_sharded, b_sharded)
print("Output shape", y.shape)
jax.debug.visualize_array_sharding(y)


Local x shape (48, 64)
Local w shape (64, 64)
Local b shape (64,)
Output shape (192, 128)


In [16]:
# @title parallel mean

z = jnp.arange(32).reshape(8, 4)
# z = jax.random.normal(jax.random.PRNGKey(1337), (batch_size, input_dim))
z_sharded = jax.device_put(z, NamedSharding(mesh, P("i", "j")))
jax.debug.visualize_array_sharding(z_sharded)


def dist_mean(inp: jax.Array) -> jax.Array:
    print("local inp shape", inp.shape)
    local_mean = jnp.mean(inp, axis=1, keepdims=True)
    # mean = jax.lax.pmean(local_mean, axis_name="j")
    j_size = jax.lax.psum(1, axis_name="j")
    global_mean_b = jax.lax.pmean(local_mean, axis_name="j")
    print("j size", j_size)
    return global_mean_b

sharded_dist_mean = shard_map(
    dist_mean, mesh, in_specs=P("i", "j"), out_specs=P("i", "j")
)

out_sharded = sharded_dist_mean(z_sharded)
jax.debug.visualize_array_sharding(out_sharded)

np_out = jax.device_get(out_sharded)
np_ref = jax.device_get(z).mean(1, keepdims=True)
print(f"diff abs max: {np.abs(np_out - np_ref).max()} | {out_sharded.shape=} {np_out.shape=} {np_ref.shape=}")

local inp shape (2, 2)
j size 2


diff abs max: 0.0 | out_sharded.shape=(8, 2) np_out.shape=(8, 2) np_ref.shape=(8, 1)


In [17]:
# @title parallel mean

z = jnp.arange(32).reshape(8, 4)
# z = jax.random.normal(jax.random.PRNGKey(1337), (batch_size, input_dim))
z_sharded = jax.device_put(z, NamedSharding(mesh, P("i", "j")))
jax.debug.visualize_array_sharding(z_sharded)


def dist_norm(inp: jax.Array) -> jax.Array:
    print("local inp shape", inp.shape)
    local_mean = jnp.mean(inp, axis=1, keepdims=True)
    global_mean_b = jax.lax.pmean(local_mean, axis_name="j")
    delta = inp - global_mean_b
    global_var_b = jax.lax.pmean(delta ** 2, axis_name="j")
    return delta / (global_var_b ** 0.5)

sharded_dist_norm = shard_map(
    dist_norm, mesh, in_specs=P("i", "j"), out_specs=P("i", "j")
)

out_sharded = sharded_dist_norm(z_sharded)
jax.debug.visualize_array_sharding(out_sharded)

np_out = jax.device_get(out_sharded)
np_z = jax.device_get(z)
np_ref_mean = np_z.mean(1, keepdims=True)
np_ref_delta = np_z - np_ref_mean
np_ref_var = (np_ref_delta ** 2).mean(1, keepdims=True)
np_ref = np_ref_delta /  (np_ref_var ** 0.5)
print(f"diff abs max: {np.abs(np_out - np_ref).max()} | {out_sharded.shape=} {np_out.shape=} {np_ref.shape=}")

local inp shape (2, 2)


diff abs max: 7.566918536205947e-08 | out_sharded.shape=(8, 4) np_out.shape=(8, 4) np_ref.shape=(8, 4)


In [36]:
dp_size = 4
tp_size = 2
mesh = Mesh(np.array(jax.devices()).reshape(dp_size, tp_size), axis_names=("dp", "tp"))
batch_size = 16
seq_len = 24
dim = 32
x = jax.random.normal(jax.random.PRNGKey(1337), shape=(batch_size, seq_len, dim))
w = jax.random.normal(jax.random.PRNGKey(130013), shape=(dim, dim))


@jax.jit
def unsharded_mul(z: jax.Array, w: jax.Array) -> jax.Array:
    return jnp.einsum("bld,de->ble", z, w)

y = unsharded_mul(x, w)
print(y.shape)


activation_spec = P("dp", None, None)
weight_spec = P("dp", "tp")
activation_sharding = NamedSharding(mesh, activation_spec)
weight_sharding = NamedSharding(mesh, weight_spec)

x_sharded = jax.device_put(x, activation_sharding)
w_sharded = jax.device_put(w, weight_sharding)

# jax.debug.visualize_array_sharding(x_sharded)
# jax.debug.visualize_array_sharding(w_sharded)

@functools.partial(shard_map, mesh=mesh, in_specs=(activation_spec, weight_spec), out_specs=activation_spec)
def sharded_mul(z_sharded: jax.Array, w_sharded: jax.Array) -> jax.Array:
    print("z sharded shape", z_sharded.shape, jax.typeof(z_sharded))
    w_full = jax.lax.all_gather(w_sharded, axis_name="dp", tiled=True, axis=0)
    print("w_full shape", w_full.shape, jax.typeof(w_full))

    local_y = jnp.einsum("bld,de->ble", z_sharded, w_full)
    print("local_y shape", local_y.shape, jax.typeof(local_y))

    global_invariant_y = jax.lax.all_gather_invariant(local_y, axis_name="tp", axis=2, tiled=True)
    # print("global_y shape", global_y.shape, jax.typeof(global_y))
    # global_invariant_y = jax.lax.all_gather_invariant(global_y, "tp")
    print("global_invariant_y shape", global_invariant_y.shape, jax.typeof(global_invariant_y))

    return global_invariant_y

y_sharded = sharded_mul(x_sharded, w_sharded)
np_x = jax.device_get(x)
np_w = jax.device_get(w)
np_ref = np.einsum("bld,de->ble", np_x, np_w)
np_y = jax.device_get(y)
np_y_sharded = jax.device_get(y_sharded)

np.testing.assert_allclose(np_ref, np_y, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(np_ref, np_y_sharded, rtol=1e-4, atol=1e-4)


(16, 24, 32)
z sharded shape (4, 24, 32) float32[4,24,32]{dp}
w_full shape (32, 16) float32[32,16]{dp,tp}
local_y shape (4, 24, 16) float32[4,24,16]{dp,tp}
global_invariant_y shape (4, 24, 32) float32[4,24,32]{dp}
