This notebook reproduces the examples from the [pjit blog post](https://irhum.github.io/blog/pjit/) by [`Irhum`](https://github.com/irhum), using the `jax.sharding` API (`Mesh`, `NamedSharding`, `PartitionSpec`) on an 8-device `(4, 2)` device mesh.

### Setup

In [1]:
import jax
import jax.numpy as jnp

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding, NamedSharding, Mesh

from jax.sharding import PartitionSpec as P

# Every example below uses a (4, 2) device mesh, i.e. 8 devices; fail fast with the
# actual device count if the host has fewer.
# Tip: on CPU, 8 virtual devices can be simulated by setting
# XLA_FLAGS=--xla_force_host_platform_device_count=8 before jax is first imported.
n_local_devices = len(jax.local_devices())
if n_local_devices < 8:
    raise RuntimeError(
        f"Notebook requires 8 devices to run (found {n_local_devices})"
    )

# jax-smi (third-party): presumably starts background tracking of device memory usage
# so it can be monitored while the examples run — see the jax-smi project docs to confirm.
from jax_smi import initialise_tracking
initialise_tracking()

In [2]:
# Arrange 8 devices as a 4 x 2 logical mesh: axis 'a' has size 4, axis 'b' has size 2.
# create_device_mesh requires the number of devices to equal the product of the mesh
# shape, so hand it exactly the first 4 * 2 devices. Otherwise a host exposing more
# than 8 devices would pass the setup cell's device-count check but fail here.
devices = mesh_utils.create_device_mesh((4, 2), devices=jax.devices()[:4 * 2])
mesh = Mesh(devices, axis_names=('a', 'b'))
# Root PRNG key. NOTE(review): not used by the cells in this section — confirm it is needed later.
key = jax.random.PRNGKey(0)

### [Case 1: Inner Axes](https://irhum.github.io/blog/pjit/#case-1-inner-axes)

In [4]:
import jax.numpy as jnp

# Build a (4, 16) float16 matrix whose row i is filled with the value i + 1:
# the vector [1, 2, 3, 4] becomes a column and is repeated 16 times along axis 1 (the columns).
v = jnp.arange(1, 5)
x = jnp.repeat(v[:, None], 16, axis=1).astype(jnp.float16)
# y is the (16, 4) transpose of x, so x @ y is (4, 4) with entry (i, j) = 16 * (i+1) * (j+1).
# JAX arrays are immutable, so no defensive copy is needed before transposing.
y = x.T

In [5]:
# Case 1: shard the contracting dimension of both operands over mesh axis 'a'
# (x's columns and y's rows); the other dimension of each is not partitioned.
x, y = (
    jax.device_put(x, NamedSharding(mesh, P(None, 'a'))),
    jax.device_put(y, NamedSharding(mesh, P('a', None))),
)
jax.debug.visualize_array_sharding(x)
jax.debug.visualize_array_sharding(y)

In [6]:
# Multiply the sharded operands; the layout of z is chosen by the compiler
# (the HLO dump further down shows the per-device program).
z = jnp.matmul(x, y)
jax.debug.visualize_array_sharding(z)
print(z)

[[ 16.  32.  48.  64.]
 [ 32.  64.  96. 128.]
 [ 48.  96. 144. 192.]
 [ 64. 128. 192. 256.]]


In [22]:
# Dump the compiled, per-device HLO for this sharded matmul.
# NOTE(review): the saved output below shows f32 operands although x and y are float16,
# so it appears to come from an earlier run — re-run this cell to refresh it.
print(
    jnp.matmul
    .lower(x, y)
    .compile()
    .as_text()
)

HloModule jit_matmul, is_scheduled=true, entry_computation_layout={(f32[4,4]{1,0:T(4,128)}, f32[4,4]{1,0:T(4,128)})->f32[4,4]{1,0:T(4,128)}}, allow_spmd_sharding_propagation_to_output={true}

%add (x: f32[], y: f32[]) -> f32[] {
  %y = f32[]{:T(256)} parameter(1)
  %x = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %x, f32[]{:T(256)} %y)
}

%bitcast_fusion (bf16input: f32[4,4]) -> f32[4,4] {
  %bf16input = f32[4,4]{1,0:T(4,128)} parameter(0)
  ROOT %bitcast = f32[4,4]{1,0:T(4,128)} bitcast(f32[4,4]{1,0:T(4,128)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f32[4,4]) -> f32[4,4] {
  %bf16input.1 = f32[4,4]{1,0:T(4,128)} parameter(0)
  ROOT %bitcast.1 = f32[4,4]{1,0:T(4,128)} bitcast(f32[4,4]{1,0:T(4,128)} %bf16input.1)
}

%fused_computation (param_0: f32[4,4], param_1: f32[4,4]) -> f32[4,4] {
  %param_0 = f32[4,4]{1,0:T(4,128)} parameter(0)
  %fusion.1 = f32[4,4]{1,0:T(4,128)} fusion(f32[4,4]{1,0:T(4,128)} %param_0), kind=kLoop, calls=%bitcast_fusion
  %p

In [7]:
# Time 10 runs x 10 loops of the sharded matmul; block_until_ready() waits for the
# asynchronously dispatched result so the measurement covers the computation itself.
%timeit -n 10 -r 10 jnp.matmul(x, y).block_until_ready()

442 µs ± 88.5 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


### [Case 1B: Mesh-axes mismatch](https://irhum.github.io/blog/pjit/#case-1b-mesh-axes-mismatch)

In [4]:
import jax.numpy as jnp

# Build a (4, 16) float16 matrix whose row i is filled with the value i + 1:
# the vector [1, 2, 3, 4] becomes a column and is repeated 16 times along axis 1 (the columns).
v = jnp.arange(1, 5)
x = jnp.repeat(v[:, None], 16, axis=1).astype(jnp.float16)
# y is the (16, 4) transpose of x, so x @ y is (4, 4) with entry (i, j) = 16 * (i+1) * (j+1).
# JAX arrays are immutable, so no defensive copy is needed before transposing.
y = x.T

In [5]:
# Case 1B: x's contracting dimension (columns) is sharded over mesh axis 'a',
# while y's contracting dimension (rows) is sharded over 'b', so the axes no longer match.
x, y = (
    jax.device_put(x, NamedSharding(mesh, P(None, 'a'))),
    jax.device_put(y, NamedSharding(mesh, P('b', None))),
)
jax.debug.visualize_array_sharding(x)
jax.debug.visualize_array_sharding(y)

In [6]:
# Same product as Case 1; only the operand shardings differ.
z = jnp.matmul(x, y)
jax.debug.visualize_array_sharding(z)

In [7]:
# Display z via its repr, which (unlike print) also reports the float16 dtype.
z

Array([[ 16.,  32.,  48.,  64.],
       [ 32.,  64.,  96., 128.],
       [ 48.,  96., 144., 192.],
       [ 64., 128., 192., 256.]], dtype=float16)

In [8]:
# Dump the compiled, per-device HLO; the entry layout shows each device holding
# a (4, 4) block of x but an (8, 4) block of y.
print(
    jnp.matmul
    .lower(x, y)
    .compile()
    .as_text()
)

HloModule jit_matmul, is_scheduled=true, entry_computation_layout={(f16[4,4]{1,0:T(4,128)(2,1)}, f16[8,4]{0,1:T(4,128)(2,1)})->f16[4,4]{1,0:T(4,128)(2,1)}}, allow_spmd_sharding_propagation_to_output={true}

%bitcast_fusion (bf16input: f16[16,4]) -> f16[16,4] {
  %bf16input = f16[16,4]{1,0:T(8,128)(2,1)} parameter(0)
  ROOT %bitcast.7 = f16[16,4]{1,0:T(8,128)(2,1)} bitcast(f16[16,4]{1,0:T(8,128)(2,1)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f16[16,4]) -> f16[16,4] {
  %bf16input.1 = f16[16,4]{0,1:T(4,128)(2,1)} parameter(0)
  ROOT %bitcast.8 = f16[16,4]{0,1:T(4,128)(2,1)} bitcast(f16[16,4]{0,1:T(4,128)(2,1)} %bf16input.1)
}

%fused_computation (param_0: f16[16,4], param_1: f16[16,4]) -> f16[4,4] {
  %param_0 = f16[16,4]{1,0:T(8,128)(2,1)} parameter(0)
  %fusion.1 = f16[16,4]{1,0:T(8,128)(2,1)} fusion(f16[16,4]{1,0:T(8,128)(2,1)} %param_0), kind=kLoop, calls=%bitcast_fusion
  %param_1 = f16[16,4]{0,1:T(4,128)(2,1)} parameter(1)
  %fusion.2 = f16[16,4]{0,1:T(4,128)(2,1)} fusion(f16

In [9]:
# Time 10 runs x 10 loops with mismatched mesh axes; block_until_ready() waits for the
# asynchronously dispatched result so the measurement covers the computation itself.
%timeit -n 10 -r 10 jnp.matmul(x, y).block_until_ready()

503 µs ± 73.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


### [Case 2: Outer Axes](https://irhum.github.io/blog/pjit/#case-2-outer-axes)

In [11]:
import jax.numpy as jnp

# Build a (4, 16) float16 matrix whose row i is filled with the value i + 1:
# the vector [1, 2, 3, 4] becomes a column and is repeated 16 times along axis 1 (the columns).
v = jnp.arange(1, 5)
x = jnp.repeat(v[:, None], 16, axis=1).astype(jnp.float16)
# y is the (16, 4) transpose of x, so x @ y is (4, 4) with entry (i, j) = 16 * (i+1) * (j+1).
# JAX arrays are immutable, so no defensive copy is needed before transposing.
y = x.T

In [12]:
# Case 2: shard x's rows over 'b' (size 2) and its columns over 'a' (size 4),
# so each device holds a (2, 4) block; max_width=120 widens the rendered diagram.
x = jax.device_put(
    x,
    NamedSharding(mesh, P('b', 'a')),
)
jax.debug.visualize_array_sharding(x, max_width=120)

In [13]:
# y's rows (the contracting dimension) are sharded over 'b'; its columns are not partitioned.
y = jax.device_put(
    y,
    NamedSharding(mesh, P('b', None)),
)
jax.debug.visualize_array_sharding(y)

In [14]:
# Multiply with x sharded on both dimensions; the compiler picks z's layout.
z = jnp.matmul(x, y)
jax.debug.visualize_array_sharding(z)
print(z)

[[ 16.  32.  48.  64.]
 [ 32.  64.  96. 128.]
 [ 48.  96. 144. 192.]
 [ 64. 128. 192. 256.]]


In [16]:
# Dump the compiled, per-device HLO; each device now sees a (2, 4) block of x
# and an (8, 4) block of y.
print(
    jnp.matmul
    .lower(x, y)
    .compile()
    .as_text()
)

HloModule jit_matmul, is_scheduled=true, entry_computation_layout={(f16[2,4]{1,0:T(4,128)(2,1)}, f16[8,4]{0,1:T(4,128)(2,1)})->f16[2,4]{1,0:T(4,128)(2,1)}}, allow_spmd_sharding_propagation_to_output={true}

%bitcast_fusion (bf16input: f16[16,2]) -> f16[16,2] {
  %bf16input = f16[16,2]{1,0:T(8,128)(2,1)} parameter(0)
  ROOT %bitcast.7 = f16[16,2]{1,0:T(8,128)(2,1)} bitcast(f16[16,2]{1,0:T(8,128)(2,1)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f16[16,4]) -> f16[16,4] {
  %bf16input.1 = f16[16,4]{0,1:T(4,128)(2,1)} parameter(0)
  ROOT %bitcast.8 = f16[16,4]{0,1:T(4,128)(2,1)} bitcast(f16[16,4]{0,1:T(4,128)(2,1)} %bf16input.1)
}

%fused_computation (param_0: f16[16,2], param_1: f16[16,4]) -> f16[2,4] {
  %param_0 = f16[16,2]{1,0:T(8,128)(2,1)} parameter(0)
  %fusion.1 = f16[16,2]{1,0:T(8,128)(2,1)} fusion(f16[16,2]{1,0:T(8,128)(2,1)} %param_0), kind=kLoop, calls=%bitcast_fusion
  %param_1 = f16[16,4]{0,1:T(4,128)(2,1)} parameter(1)
  %fusion.2 = f16[16,4]{0,1:T(4,128)(2,1)} fusion(f16

In [15]:
# Time 10 runs x 10 loops of the outer-axes matmul; block_until_ready() waits for the
# asynchronously dispatched result so the measurement covers the computation itself.
%timeit -n 10 -r 10 jnp.matmul(x, y).block_until_ready()

463 µs ± 85.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


### [Full Sharding](https://irhum.github.io/blog/pjit/#full-sharding)

In [4]:
import jax.numpy as jnp

# Build a (4, 16) float16 matrix whose row i is filled with the value i + 1:
# the vector [1, 2, 3, 4] becomes a column and is repeated 16 times along axis 1 (the columns).
v = jnp.arange(1, 5)
x = jnp.repeat(v[:, None], 16, axis=1).astype(jnp.float16)
# y is the (16, 4) transpose of x, so x @ y is (4, 4) with entry (i, j) = 16 * (i+1) * (j+1).
# JAX arrays are immutable, so no defensive copy is needed before transposing.
y = x.T

In [5]:
# Full sharding: x's rows go over 'b' and its columns over 'a',
# so each device holds a (2, 4) block; max_width=120 widens the rendered diagram.
x = jax.device_put(
    x,
    NamedSharding(mesh, P('b', 'a')),
)
jax.debug.visualize_array_sharding(x, max_width=120)

In [6]:
# y uses the same spec: rows over 'b', columns over 'a', so each device
# holds an (8, 1) block of the (16, 4) matrix.
y = jax.device_put(
    y,
    NamedSharding(mesh, P('b', 'a')),
)
jax.debug.visualize_array_sharding(y, max_width=120)

In [7]:
# Multiply the fully sharded operands; the compiler picks z's layout.
z = jnp.matmul(x, y)
jax.debug.visualize_array_sharding(z)
print(z)

[[ 16.  32.  48.  64.]
 [ 32.  64.  96. 128.]
 [ 48.  96. 144. 192.]
 [ 64. 128. 192. 256.]]


In [9]:
# Dump the compiled, per-device HLO; in the saved output an all-gather appears
# alongside the local dot.
print(
    jnp.matmul
    .lower(x, y)
    .compile()
    .as_text()
)

HloModule jit_matmul, is_scheduled=true, entry_computation_layout={(f16[2,4]{1,0:T(4,128)(2,1)}, f16[8,1]{0,1:T(4,128)(2,1)})->f16[2,1]{1,0:T(4,128)(2,1)}}, allow_spmd_sharding_propagation_to_output={true}

%all-gather.3.reduce_sub_computation (lhs: f16[], rhs: f16[]) -> f16[] {
  %lhs = f16[] parameter(0)
  %rhs = f16[] parameter(1)
  ROOT %add.1 = f16[] add(f16[] %lhs, f16[] %rhs)
}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f16[16,2], param_1.3: f16[16]) -> f32[2] {
  %param_0.2 = f16[16,2]{0,1:T(4,128)(2,1)} parameter(0)
  %param_1.3 = f16[16]{0:T(512)(128)(2,1)} parameter(1)
  %broadcast.3 = f16[16,2]{0,1:T(4,128)(2,1)} broadcast(f16[16]{0:T(512)(128)(2,1)} %param_1.3), dimensions={0}, metadata={op_name="jit(matmul)/jit(main)/dot_general[dime

In [10]:
# Time 10 runs x 10 loops of the fully sharded matmul; block_until_ready() waits for the
# asynchronously dispatched result so the measurement covers the computation itself.
%timeit -n 10 -r 10 jnp.matmul(x, y).block_until_ready()

469 µs ± 72.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


### [Sharding: GSPMD-style](https://irhum.github.io/blog/pjit/#sharding-gspmd-style)