In [1]:
import functools
import os

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

import numpy as np

In [2]:
jax.config.update('jax_num_cpu_devices', 4)

In [3]:
from pathlib import Path
import subprocess

def project_root() -> Path:
    return Path(subprocess.check_output(
        ['git', 'rev-parse', '--show-toplevel']
    ).decode().strip())

TRACES_DIR = project_root() / "traces"

In [4]:
m, k, n = 2048, 2048, 1024

k1, k2 = jax.random.split(jax.random.key(0), 2)
inputs = jax.random.normal(k1, (m, k), dtype=jnp.bfloat16)
weights = jax.random.normal(k2, (k, n), dtype=jnp.bfloat16)

In [5]:
num_devices = jax.device_count()
mesh = jax.make_mesh((2, 2), ("x", "y"))
inp_sharding = jax.NamedSharding(mesh, P('x', 'y'))
w_sharding = jax.NamedSharding(mesh, P('x', None))
o_sharding = jax.NamedSharding(mesh, P('x', None))

inputs = jax.device_put(inputs, inp_sharding)
weights = jax.device_put(weights, w_sharding)

  mesh = jax.make_mesh((2, 2), ("x", "y"))


inputs are size 2048, 2048 -> bf16:: 2 bytes * 2048 * 2048 = ~8MB
weights are size 2048, 1024 -> bf16:: 2 bytes * 2048 * 1024 = ~4MB

inputs sharded along x and y -> $Inp[I_{X}, J_{Y}]$

weights sharded along x -> $W[J_{X}, K]$

Each device has N elements per array:
  - inputs
    - (2048 / 2) * (2048 / 2) * 2bytes
    - ~2MB
  - weights
    - (2048 / 2) * 1024 * 2bytes
    - ~2MB

The contracting dimension is sharded in both inputs and weights, along different axes.
Need to handle that with collectives; AG/AR

In [6]:
jax.debug.visualize_array_sharding(inputs)

In [7]:
jax.debug.visualize_array_sharding(weights)

In [6]:
def basic_matmul(x: jax.Array, y: jax.Array) -> jax.Array:
    return jnp.matmul(x, y)

out = basic_matmul(inputs, weights)
compiled = jax.jit(basic_matmul)

In [7]:
jax.debug.visualize_array_sharding(out)

In [10]:
result = compiled(inputs, weights)
result.block_until_ready()

with jax.profiler.trace(TRACES_DIR):
    result = compiled(inputs, weights)
    result.block_until_ready()

In [8]:
@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul(input_shard: jax.Array, w_shard: jax.Array) -> jax.Array:
    # First we want to all_gather the data
    with jax.named_scope('all_gather(s)'):
        input_full = jax.lax.all_gather(input_shard, 'y', axis=1, tiled=True)
        w_full = jax.lax.all_gather(w_shard, 'x', axis=0, tiled=True) # gather w along x
    # Then we want to compute on the data
    with jax.named_scope('dot'):
        local_out = input_full @ w_full
    # Then we want to all reduce the data
    # with jax.named_scope('all_reduce'):
    #     out = jax.lax.psum(local_out, 'y')
    return local_out

In [11]:
# https://docs.jax.dev/en/latest/notebooks/shard_map.html
from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(functools.partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(xla_matmul(inputs, weights), jnp.dot(inputs, weights))

True

In [13]:
gemm2_compiled = jax.jit(xla_matmul)
result = gemm2_compiled(inputs, weights)
result.block_until_ready()

with jax.profiler.trace(TRACES_DIR):
    result = gemm2_compiled(inputs, weights)
    result.block_until_ready()

In [9]:
@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul2(input_shard: jax.Array, weight_shard: jax.Array) -> jax.Array:
    """
    This time, we want to make the computation a little more efficient than
    stacking the two all gathers are the beginning of the kernel
    """
    y_idx = jax.lax.axis_index('y')
    # All gather the weights over x so that each device contains full copy
    w_full = jax.lax.all_gather(weight_shard, 'x', axis=0, tiled=True)
    # Using the y-ring axis to determined which col stripe of weights to compute locally
    w_slice = jax.lax.dynamic_slice(w_full, (y_idx * 1024, 0), (1024, 1024))
    local_out = input_shard @ w_slice
    # All Reduce over the y-ring to accumulate partial results
    out = jax.lax.psum(local_out, 'y')
    return out

In [13]:
a = xla_matmul2(inputs, weights)

In [14]:
allclose(basic_matmul(inputs, weights), xla_matmul2(inputs, weights))

False

In [10]:
@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul3(input_shard: jax.Array, weight_shard: jax.Array) -> jax.Array:
    """
    Use some higher precision numerics to accomodate accumulation order
    """
    y_idx = jax.lax.axis_index('y')
    w_full = jax.lax.all_gather(weight_shard, 'x', axis=0, tiled=True)
    # This shouldn't hardcode the dim shapes
    w_slice = jax.lax.dynamic_slice(w_full, (y_idx * 1024, 0), (1024, 1024))
    # This is probably overkill, think it might also incur perf penalty
    #   - Something in the docs about precision highest
    local_out = jax.lax.dot_general(
        input_shard, w_slice,
        dimension_numbers=(((1,), (0,)), ((), ())),
        precision=jax.lax.Precision.HIGHEST,
        preferred_element_type=jnp.float32,
    )
    out = jax.lax.psum(local_out, 'y')
    return out

In [11]:
jnp.allclose(xla_matmul3(inputs, weights), basic_matmul(inputs, weights), rtol=1e-2, atol=1e-2)

Array(True, dtype=bool)

In [None]:
"""
Let's recall what we've learned so far --

When needed to perform an all gather on the reduction axis of our weights
to remove the sharding over X

Then, we compute local MatMuls (slicing out the appropriate data) between
the shard local inputs and the full weights
- Recall, we have the full weights after AG, so we need to slice out the
  appropriate chunks of W for our computation

These MatMuls are accumulators -> They finally need to be all reduced
over Y. The desired out sharding is ('x', None), so when we do an AR
over the Y axis, we are sharing the partial results
"""

"""
Let's sketch out the algorithm we care about --

We have 2 arrays distributed over 4 devices
  - 1/4 of inputs on each device
  - 1/2 of weights of each device

We want to efficiently compute this distributed matmul over the devices

We know that the contracting dims are sharded differently

So there will need to be some comms to unshard so that we have the
whole array in the right place. HOWEVER, we may also be able to get
away with ppermute to simple pass _results_ after compute is finished

Here are the kernels we may want to try
  - Simple matmul with lax collectives inserted in the right spots
  - MatMul with handrolled collectives (still AG to start, then AR)
  - ppermute
    - Issue async DMA
    - Run local compute; stash in accumulator
    - 
    - How much latency can we hide here?
      - If the DMAs are fast/slow?
      - How do we _reason_ about these tradeoffs

Start with 2x2 case; don't worry too much about abstracting things out
  - Then extend to larger configurations + abstractions
  - How do we think about the work we're doing?
  - Where are the opportunities to show different edge cases?
  - Where do our assumptions break down?
  - emit_pipeline
  - kernel schedule?
  - What happens when we run this on Trillium?
    - What changes?
"""

In [12]:
"""
Here's what happens:
- We have 2 arrays in HBM
- Need to all_gather weights over x, now each device has fully copy of x
- Then do local compute
- All Reduce the local compute to get the correct results
"""

#NOTE: THIS IS THE CASE WHERE WE SIMPLY REPLACE jax.dot/jax.matmul/x @ y in xla_matmul3
def simple_matmul(x_ref, y_ref, o_ref, scratch_ref, *, n_steps):
  # Zero scratch buffer
  @pl.when(pl.program_id(2) == 0)
  def _init_scratch():
    scratch_ref[...] = jnp.zeros_like(scratch_ref)

  # Compute dot
  scratch_ref[...] += jnp.dot(
    x_ref[...],
    y_ref[...],
    preferred_element_type=jnp.float32
  )

  # Flush to HBM
  @pl.when(pl.program_id(2) == n_steps - 1)
  def _flush_scratch():
    o_ref[...] = scratch_ref[...].astype(o_ref.dtype)


def make_matmul(
  x: jax.Array,
  y: jax.Array,
  *,
  bm: int = 128,
  bk: int = 128,
  bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape

  grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    grid=(m//bm, n//bn, k//bk),
    in_specs=[
      pl.BlockSpec((bm, bk), lambda i,j,k: (i, k)),
      pl.BlockSpec((bk, bn), lambda i,j,k: (k, j))
    ],
    out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
    scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)]
  )

  return pl.pallas_call(
    functools.partial(simple_matmul, n_steps=k//bk),
    grid_spec=grid_spec,
    # Made this float32 to appease the numerics gods
    out_shape=jax.ShapeDtypeStruct((m, n), dtype=jnp.float32),
    interpret=True
  )(x, y)


def distributed_gemm_kernel1(inputs, weights):
  y_idx = jax.lax.axis_index('y')
  # AG
  w_full = jax.lax.all_gather(weights, 'x', axis=0, tiled=True)
  # jax.debug.print('w_full: {}', w_full.shape)
  # Slice out local arrays
  # TODO: again, fix these so they're not tied to the specific shapes
  w_slice = jax.lax.dynamic_slice(w_full, (y_idx * 1024, 0), (1024, 1024))
  # jax.debug.print('w_slice: {}', w_slice.shape)
  # jax.debug.print('input_shape: {}', inputs.shape)
  # We'll take the default tile sizes for now
  local_out = make_matmul(inputs, w_slice)
  # jax.debug.print('local_out: {}', local_out.shape)
  return jax.lax.psum(local_out, 'y')

In [13]:
dgk1 = jax.jit(
    jax.shard_map(
    distributed_gemm_kernel1,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
))

In [14]:
jnp.allclose(dgk1(inputs, weights), basic_matmul(inputs, weights), rtol=1e-2, atol=1e-2)

Array(True, dtype=bool)

In [20]:
ref = basic_matmul(inputs, weights)
test = dgk1(inputs, weights)
diff = jnp.abs(ref - test)

print(f"max diff: {jnp.max(diff)}")
print(f"mean diff: {jnp.mean(diff)}")
print(f"median diff: {jnp.median(diff)}")
print(f"% > 0.1: {100 * jnp.mean(diff > 0.1):.2f}%")

# Location of errors
bad_mask = diff > 0.1
bad_rows = jnp.any(bad_mask, axis=1)
bad_cols = jnp.any(bad_mask, axis=0)
print(f"rows with errors: {jnp.sum(bad_rows)}  {ref.shape[0]}")
print(f"cols with errors: {jnp.sum(bad_cols)} {ref.shape[1]}")

# Worst error location
bad_idx = jnp.argmax(diff)
i, j = bad_idx // ref.shape[1], bad_idx % ref.shape[1]
print(f"worst error at [{i}, {j}]: ref={ref[i,j]}, test={test[i,j]}")

# Check quadrants (shard boundaries)
quadrants = [
    ("top-left", diff[:1024, :512]),
    ("top-right", diff[:1024, 512:]),
    ("bottom-left", diff[1024:, :512]),
    ("bottom-right", diff[1024:, 512:]),
]
for name, q in quadrants:
    mean_val = float(jnp.mean(q))
    pct_bad = float(100 * jnp.mean(q > 0.1))
    print(f"{name}: mean={mean_val:.4f}, % bad={pct_bad:.1f}%")

max diff: 0.5
mean diff: 0.05057927221059799
median diff: 0.031246185302734375
% > 0.1: 15.81%
rows with errors: 2048  2048
cols with errors: 1024 1024
worst error at [875, 791]: ref=128, test=128.5
top-left: mean=0.0507, % bad=15.9%
top-right: mean=0.0506, % bad=15.8%
bottom-left: mean=0.0505, % bad=15.8%
bottom-right: mean=0.0505, % bad=15.7%


In [21]:
import time
import statistics

for _ in range(3):
    dgk1(inputs, weights).block_until_ready()
print("done")

# Benchmark
print("Running benchmark...", end=" ", flush=True)
times = []
for _ in range(50):
    start = time.perf_counter()
    dgk1(inputs, weights)
    end = time.perf_counter()
    times.append((end - start) * 1000)  # Convert to ms
print("done")
print()

# Statistics
mean_ms = statistics.mean(times)
median_ms = statistics.median(times)
stdev_ms = statistics.stdev(times) if len(times) > 1 else 0
min_ms = min(times)
max_ms = max(times)
p95 = np.percentile(times, 95)
p99 = np.percentile(times, 99)

print("Results:")
print(f"  Mean:   {mean_ms:>10.3f} ms")
print(f"  Median: {median_ms:>10.3f} ms")
print(f"  Stdev:  {stdev_ms:>10.3f} ms")
print(f"  Min:    {min_ms:>10.3f} ms")
print(f"  P95:    {p95:>10.3f} ms")
print(f"  P99:    {p99:>10.3f} ms")

done
Running benchmark... done

Results:
  Mean:      199.003 ms
  Median:      1.530 ms
  Stdev:     272.175 ms
  Min:         0.022 ms
  P95:       650.769 ms
  P99:       722.780 ms


In [None]:
# I think this one should have the matmul done after the full AG
# Then we can interleave with ppermute

# FIRST: Work out the AllGather
# SECOND: Work out the MatMul
# FINAL: Work out the AR

"""
On each local device, we have 2 arrays sitting there -> inputs + weights
On kernel start:
  - Barrier sync to get everyone on the same stage
    - Might could relax this constraint?
  - Issue remote DMAs along the ... ring to all gather the weights
  - Once received
    - Compute local dots
    - Accumulate
  - Finally
    - All Reduce over the y ring to conform shmap shape
"""

def all_gather_kernel_1D(
  input_ref, output_ref,
  local_send_sem, send_sem, recv_sem, 
  # ...
):
  """
  These two refs are intended to decouple what's going on (that's the best you can do?)

  The input ref is the local ref, the output ref is the ref that will be
  sending/receiving data from our neighboring devices
  """
  #TODO: Barrier
  
  # Get tensor dims/sizes
  # This should be baked into compiled artifact? Or is it runtime?
  # There has to be a prettier way to do this?
  shard_height = input_ref.shape[0]
  shard_width = input_ref.shape[1]

  # Get neighbors
  # Map to some position in [(0,0), (1,0), (0,1), (1,1)] along x
  # left_dev = jax.lax.rem(device_id - 1, x_ring)
  this_device_x = jax.lax.axis_index('x')
  this_device_y = jax.lax.axis_index('y')
  x_ring = jax.lax.axis_size('x')
  y_ring = jax.lax.axis_size('y')
  right_dev = jax.lax.rem(this_device_x + 1, x_ring)

  # Hard coding 2: where 2 is supposed to be the ring_length
  neighbor_x = (this_device_x + 1) % x_ring
  # neighbor_linear = neighbor_x * y_ring + this_device_y

  # PERFORM INITIAL ASYNC COPY FROM OUR HBM TO OUR HBM
  # @pl.when(pl.program_id(0) == 0) -> We're just copying within our HBM to a bigger HBM memory
  # XLA liveness should handle malloc/free the _INPUT_ tensor once the AG completes
  #   def _copy_local_to_local 
  local_hbm_copy = pltpu.make_async_copy(
    src_ref=input_ref,
    dst_ref=output_ref.at[pl.ds(this_device_x * shard_height, shard_height), :],
    sem=local_send_sem
  )

  # We can defer the wait until literally the very end of the kernel
  local_hbm_copy.start()

  # Issue RDMA
  #NOTE: This is buggy; this won't work for axis lengths > 2
  # Logic would keep writing to the same location in the neighbor's out_ref*
  right_dma = pltpu.make_async_remote_copy(
    src_ref=input_ref.at[...],
    dst_ref=output_ref.at[pl.ds(this_device_x * shard_height, shard_height), :],
    send_sem=send_sem,
    recv_sem=recv_sem,
    #NOTE: device_id has to match the mesh specs
    # Since we're in a 2x2 grid -> Need to communication _which_ links we're using
    # device_id=(right_dev,),
    device_id_type=pltpu.DeviceIdType.MESH,
    device_id=(right_dev, this_device_y),
    # device_id=(right_dev),
    # device_id=neighbor_linear,
    # device_id_type=pltpu.DeviceIdType.LOGICAL,
  )

  # Wait on RDMA (send/recv)
  right_dma.start()
  right_dma.wait()
  local_hbm_copy.wait()

grid_spec = pltpu.PrefetchScalarGridSpec(
  num_scalar_prefetch=0,
  # This logic is wrong -> this is not a "line" of devices, but a 2x2 grid
  # We only want to iterate along the number of devices PER RING - 1 times
  grid=(1,), # Could we move this elsewhere?
  in_specs=[
    # Our input reference is just our big tensor in HBM
    pl.BlockSpec(memory_space=pl.ANY)
  ],
  # Our output reference will be _another_ big tensor in HBM
  out_specs=pl.BlockSpec(memory_space=pl.ANY),
  # This will be an error if you need more semaphores for more neighbors
  scratch_shapes=(
    [pltpu.SemaphoreType.DMA] * 3
  )
)

out_shape=jax.ShapeDtypeStruct((weights.shape), dtype=jnp.bfloat16)

def make_ag(x):
  # TODO: should we parameterize this _here_; aka pass in the shard shape data here?
  return pl.pallas_call(
    all_gather_kernel_1D,
    grid_spec=grid_spec,
    out_shape=out_shape,
    interpret=True
  )(x)

In [None]:
"""
Basically --
The way we organize the grid for communications is a decision
  - Rings, ranges, etc.
  - This trades off bandwidth/latency


# Ring
grid = (ring_size - 1,)

# Recursive doubling
grid = (log2(ring_size),)

# Direct
grid = (1,)  # but more complex RDMA pattern
"""

In [67]:
def xla_allgather(weights):
    w_full = jax.lax.all_gather(weights, 'x', axis=0, tiled=True)
    return w_full

res = jax.jit(
    jax.shard_map(
        xla_allgather,
        mesh=mesh,
        in_specs=P('x', None),
        out_specs=P(),
        check_vma=False
    )
)(weights)

In [107]:
res2 = jax.jit(
    jax.shard_map(
        make_ag,
        mesh=mesh,
        in_specs=P('x', None),
        out_specs=P(None, None),
        check_vma=False
    )
)(weights)

NotImplementedError: Meshes with more than 1 named dimension not implemented in dma_start_p

In [None]:
"""
Here's something fun --

NotImplementedError: Meshes with more than 1 named
dimension not implemented in dma_start_p

Pallas's remote DMA primitives currently only support 1D meshes. That's a real limitation.

https://github.com/jax-ml/jax/blob/main/jax/_src/pallas/mosaic/primitives.py
"""

In [None]:
def no_shmap_ag_kernel(
    input_ref, output_ref,
    local_copy_sem, send_sem, recv_sem
):
  # Get axis detail
  # Make Local HBM <--> HBM copy
  # Make Remote HBM <--> HBM copy

  pass

# def no_shmap_ag(input):
#     return pl.pallas_call(
#         kernel,
#         grid_spec=(),
#         out_shape=(),
        
#     )(input)

In [None]:
def all_gather_kernel_2D(
  x
):
  # Barrier
  
  # Get neighbors

  # Issue RDMA

  # Wait on RDMA (send/recv)
  # NOTE: need to change the semaphores/refs to handle reads/writes from multiple neighbors
  # NOTE: In a 2x2 grid, there's only one neighbor**
  # Does this need to be accounted for in code? -> For loop type construct to handle this  

In [None]:
def new_all_gather():
    pass


res = jax.jit(
    jax.shard_map(
        new_all_gather,
        mesh=mesh,
        in_specs=(P(), P()),
        out_specs=P(),
        # check_vma=False
    )
)

In [None]:
# jax.shard_map(
#     ...
# )(inputs, weights)

# emit_pipeline version

In [None]:
# PUT IT TOGETHER
# INTERLEAVE THE COMPUTE with ppermute
# emit_pipeline??

In [None]:
# NUMERICS CHECK

# ACCURACY CHECK

In [None]:
# PERFORMANCE CHECK

In [None]:
# PROFILING

In [None]:
# EXTEND THE KERNEL TO BE MORE GENERAL THAN 2x2
# REPEAT CHECKS
# SCALE UP TO 4x4 GRID