In [1]:
import os
import functools
from typing import Optional

import numpy as np
import jax
import jax.numpy as jnp

In [2]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

In [3]:
# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((2,)))

In [4]:
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
jax.debug.visualize_array_sharding(x)
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(2, 1))
jax.debug.visualize_array_sharding(y)

In [5]:
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)

In [6]:
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()

The slowest run took 5.94 times longer than the fastest. This could mean that an intermediate result is being cached.
4.34 ms ± 4.26 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [7]:
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

1.05 ms ± 7.85 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [8]:
import jax

x = jax.random.normal(jax.random.key(0), (8192, 8192))

In [9]:
jax.debug.visualize_array_sharding(x)

In [10]:
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((2,))
devices

array([cuda(id=0), cuda(id=1)], dtype=object)

In [15]:
from jax.sharding import PositionalSharding

sharding = PositionalSharding(devices)

x = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(x)
sharding

AssertionError: (1, 2)

In [12]:
sharding.reshape(1, 2)
print(sharding)

sharding.reshape(2, 1)
print(sharding)

PositionalSharding([{GPU 0} {GPU 1}], shape=(2,))
PositionalSharding([{GPU 0} {GPU 1}], shape=(2,))


In [13]:
x.sharding

PositionalSharding([[{GPU 0}]
                    [{GPU 1}]], shape=(2, 1))

In [14]:
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)

AssertionError: (1, 2)