<a href="https://colab.research.google.com/github/HarounH/smol/blob/main/rl/distributed_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

# Flag applied in CPU-only mode: simulate 8 host devices so that multi-device
# sharding can be exercised on a machine without accelerators.
CPU_XLA_FLAGS = ("--xla_force_host_platform_device_count=8",)

# Flags applied when running on GPU.
# NOTE(review): XLA flags come and go between releases — confirm each of
# these is still accepted by the installed jaxlib before enabling GPU mode.
GPU_XLA_FLAGS = (
    "--xla_gpu_enable_triton_softmax_fusion=true",
    "--xla_gpu_triton_gemm_any=false",
    "--xla_gpu_enable_async_collectives=true",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_highest_priority_async_stream=true",
)


def merge_xla_flags(existing, extra):
    """Return `existing` with every flag from `extra` appended exactly once.

    Args:
        existing: Current whitespace-separated flag string (may be empty).
        extra: Iterable of individual flag strings to add.

    Returns:
        A single-space-separated flag string. Flags already present in
        `existing` (exact token match) are not appended again, so calling
        this repeatedly with the same `extra` is a no-op after the first call.
    """
    tokens = existing.split()
    for flag in extra:
        # Exact-token comparison: the same flag with a different value is
        # still appended, matching the previous "append after" behaviour.
        if flag not in tokens:
            tokens.append(flag)
    # Joining with a single space guarantees a separator between the
    # pre-existing flags and the new ones (the GPU branch used to glue its
    # first flag directly onto whatever XLA_FLAGS already contained).
    return " ".join(tokens)


if USE_CPU_ONLY:
    # Enforce CPU-only execution by hiding all CUDA devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

# This must stay ahead of `import jax` below so the flags are in the
# environment before JAX is loaded. Because the merge de-duplicates,
# re-running this cell in the same kernel no longer grows XLA_FLAGS.
flags = merge_xla_flags(
    os.environ.get("XLA_FLAGS", ""),
    CPU_XLA_FLAGS if USE_CPU_ONLY else GPU_XLA_FLAGS,
)
os.environ["XLA_FLAGS"] = flags
import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

print(jax.devices())

a = jnp.arange(16)
print("Array", a)
print("Device", a.device)
print("Sharding", a.sharding)

mesh = Mesh(np.array(jax.devices()), ("i",))
print(mesh)

sharding = NamedSharding(
    mesh,
    P("i"),
)

a_sharded = jax.device_put(a, sharding)
print("Sharded array", a_sharded)
print("Device", a_sharded.devices())
print("Sharding", a_sharded.sharding)
jax.debug.visualize_array_sharding(a_sharded)


[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Array [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
Device TFRT_CPU_0
Sharding SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
Mesh('i': 8, axis_types=(Auto,))
Sharded array [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
Device {CpuDevice(id=6), CpuDevice(id=0), CpuDevice(id=3), CpuDevice(id=1), CpuDevice(id=7), CpuDevice(id=5), CpuDevice(id=2), CpuDevice(id=4)}
Sharding NamedSharding(mesh=Mesh('i': 8, axis_types=(Auto,)), spec=PartitionSpec('i',), memory_kind=device)
