In [1]:
import functools
import os

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

In [2]:
# Expose 4 virtual CPU devices so the 2x2 ("x", "y") mesh built below fits on
# a single host. Per the JAX config docs this has to be set before the JAX
# backend is initialized (i.e. before the first array is created).
jax.config.update('jax_num_cpu_devices', 4)

In [3]:
from pathlib import Path
import subprocess

def project_root() -> Path:
    """Return the top-level directory of the git work tree containing the CWD.

    Shells out to ``git rev-parse --show-toplevel``. ``check_output`` raises
    ``subprocess.CalledProcessError`` if git exits non-zero (e.g. not inside a
    work tree) and ``FileNotFoundError`` if git is not installed.
    """
    git_cmd = ['git', 'rev-parse', '--show-toplevel']
    toplevel = subprocess.check_output(git_cmd).decode().strip()
    return Path(toplevel)

# Profiler traces are written under <repo root>/traces.
TRACES_DIR = project_root() / "traces"

In [4]:
# Problem size: (m, k) @ (k, n) -> (m, n).
m, k, n = 2048, 2048, 1024

# Fixed PRNG seed so every Restart & Run All produces identical data.
k1, k2 = jax.random.split(jax.random.key(0), 2)
inputs = jax.random.normal(k1, shape=(m, k), dtype=jnp.bfloat16)   # (m, k)
weights = jax.random.normal(k2, shape=(k, n), dtype=jnp.bfloat16)  # (k, n)

In [5]:
# 2x2 device mesh. Axis "x" splits the rows (m) of `inputs` and the
# contracting dim (k) of `weights`; axis "y" splits the contracting dim of
# `inputs`.
# axis_types is pinned to Auto: calling make_mesh without it warns that the
# default is changing to Explicit, and the plain `jnp.matmul` used later in
# this notebook was written assuming Auto-mode (compiler-propagated) sharding.
mesh = jax.make_mesh(
    (2, 2),
    ("x", "y"),
    axis_types=(jax.sharding.AxisType.Auto,) * 2,
)
inp_sharding = NamedSharding(mesh, P('x', 'y'))  # rows over x, k over y
w_sharding = NamedSharding(mesh, P('x', None))   # k over x, replicated over y

inputs = jax.device_put(inputs, inp_sharding)
weights = jax.device_put(weights, w_sharding)

  mesh = jax.make_mesh((2, 2), ("x", "y"))


inputs has shape (2048, 2048) in bf16: 2 bytes * 2048 * 2048 = 8 MiB
weights has shape (2048, 1024) in bf16: 2 bytes * 2048 * 1024 = 4 MiB

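inputs is sharded with its rows (M) over mesh axis x and its contracting dimension (K) over y -> $Inp[M_{X}, K_{Y}]$

weights is sharded with its contracting dimension (K) over x and N replicated -> $W[K_{X}, N]$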

Per-device shard sizes:
  - inputs
    - (2048 / 2) * (2048 / 2) * 2 bytes
    - ~2 MiB
  - weights (split over x, replicated over y)
    - (2048 / 2) * 1024 * 2 bytes
    - ~2 MiB

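The contracting dimension K is sharded in both inputs and weights, but along different mesh axes (y for inputs, x for weights), so the local shards cannot simply be multiplied together.
This has to be resolved with collectives: all-gather (AG) and/or all-reduce (AR).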

In [6]:
# Expect a 2x2 grid: rows split over mesh axis x, columns over y (P('x', 'y')).
jax.debug.visualize_array_sharding(inputs)

In [7]:
# Expect 2 row blocks split over x, each replicated across y (P('x', None)).
jax.debug.visualize_array_sharding(weights)

In [8]:
def basic_matmul(x: jax.Array, y: jax.Array) -> jax.Array:
    """Baseline matrix product via ``jnp.matmul`` with no explicit collectives.

    When the operands are sharded, how the product is partitioned is left
    entirely to JAX/XLA; this is the reference the shard_map versions are
    compared against.
    """
    product = jnp.matmul(x, y)
    return product

# Eager (un-jitted) call on the sharded global arrays; `out` is visualized below.
out = basic_matmul(inputs, weights)
# jax.jit is lazy: tracing and compilation happen on the first call.
compiled = jax.jit(basic_matmul)

In [9]:
# Output sharding of the eager matmul, as chosen by JAX/XLA.
jax.debug.visualize_array_sharding(out)

In [10]:
# Warm-up call so tracing/compilation is excluded from the profiled region.
result = compiled(inputs, weights)
result.block_until_ready()

# One log dir per experiment, so each run is identifiable in TensorBoard/XProf
# instead of every experiment landing directly in TRACES_DIR.
with jax.profiler.trace(TRACES_DIR / "basic_matmul"):
    result = compiled(inputs, weights)
    # Dispatch is asynchronous; block so the work happens inside the trace.
    result.block_until_ready()

In [11]:
@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul(input_shard: jax.Array, w_shard: jax.Array) -> jax.Array:
    """All-gather-then-multiply: gather the contracting dim K, then dot locally.

    Per-device operands (from in_specs):
      input_shard: rows split over 'x', K split over 'y'
      w_shard:     K split over 'x', N unsplit

    After both gathers each device holds a full-K row block of the inputs and
    all of the weights, so the local dot is already the complete contraction
    for its rows and no reduction is needed afterwards. Devices along 'y'
    compute identical row blocks, matching out_specs=P('x', None).
    """
    with jax.named_scope('all_gather(s)'):
        # Concatenate the K shards held along 'y' -> full K for this row block.
        input_full = jax.lax.all_gather(input_shard, 'y', axis=1, tiled=True)
        # Concatenate the K shards held along 'x' -> full weight matrix.
        w_full = jax.lax.all_gather(w_shard, 'x', axis=0, tiled=True)
    with jax.named_scope('dot'):
        # Full contraction over K. Do NOT psum over 'y' here: every y-device
        # already has the complete result, so a psum would double it.
        local_out = input_full @ w_full
    return local_out

In [12]:
# https://docs.jax.dev/en/latest/notebooks/shard_map.html
from jax.tree_util import tree_map, tree_all

def allclose(a, b, atol=1e-2, rtol=1e-2):
    """Return True if every pair of corresponding leaves in pytrees `a`, `b` is close.

    Leaves are compared with `jnp.allclose(..., atol=atol, rtol=rtol)`. The
    loose defaults are meant for the bf16 comparisons in this notebook; pass
    tighter tolerances for higher-precision dtypes.
    """
    close = functools.partial(jnp.allclose, atol=atol, rtol=rtol)
    return tree_all(tree_map(close, a, b))

# shard_map result vs. JAX's own dot on the same global arrays; should be True.
allclose(xla_matmul(inputs, weights), jnp.dot(inputs, weights))

True

In [13]:
# jit the shard_map'd kernel; the first call below traces and compiles it.
gemm2_compiled = jax.jit(xla_matmul)
# Warm-up call so compilation is excluded from the profiled region.
result = gemm2_compiled(inputs, weights)
result.block_until_ready()

# One log dir per experiment, so this run is distinguishable from the
# basic_matmul trace in TensorBoard/XProf.
with jax.profiler.trace(TRACES_DIR / "xla_matmul"):
    result = gemm2_compiled(inputs, weights)
    # Dispatch is asynchronous; block so the work happens inside the trace.
    result.block_until_ready()

In [14]:
@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul2(input_shard: jax.Array, weight_shard: jax.Array) -> jax.Array:
    """Placeholder for a second shard_map matmul variant (not written yet).

    Uses the same in/out specs as `xla_matmul`. Raises NotImplementedError so
    an accidental call fails with a clear message instead of returning None,
    which does not match out_specs.
    """
    # TODO: implement an alternative collective strategy.
    raise NotImplementedError("xla_matmul2 is not implemented yet")

In [None]:
#TODO: Add the Pallas impl with profiling