In [4]:
import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

import numpy as np

from sfp.utils import benchmark, numerics, profile, upload_to_gcs

# Expose 4 host (CPU) devices so the 2x2 mesh below can also be built on a
# CPU-only machine. This flag only configures the CPU backend and has to be set
# before any backend is initialized.
jax.config.update('jax_num_cpu_devices', 4)

In [None]:
# Problem size: inputs (m, k) @ weights (k, n) -> output (m, n).
m, k, n = 2048, 2048, 1024

# Fixed PRNG key so every run of the notebook sees the same bf16 operands.
k1, k2 = jax.random.split(jax.random.key(0), 2)
inputs = jax.random.normal(k1, (m, k), dtype=jnp.bfloat16)
weights = jax.random.normal(k2, (k, n), dtype=jnp.bfloat16)

In [None]:
# NOTE(review): num_devices is not used below; the mesh is hard-coded to 2x2 (4 devices).
num_devices = jax.device_count()
# 2x2 device mesh with named axes "x" and "y".
mesh = jax.make_mesh((2, 2), ("x", "y"))
# inputs[I, J]: rows (I) sharded over x, contracting dim (J) sharded over y.
inp_sharding = NamedSharding(mesh, P('x', 'y'))
# weights[J, K]: contracting dim (J) sharded over x, replicated over y.
w_sharding = NamedSharding(mesh, P('x', None))

inputs = jax.device_put(inputs, inp_sharding)
weights = jax.device_put(weights, w_sharding)

  mesh = jax.make_mesh((2, 2), ("x", "y")) # (jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Explicit)


inputs have shape (2048, 2048) in bf16: 2 bytes * 2048 * 2048 = 8 MiB
weights have shape (2048, 1024) in bf16: 2 bytes * 2048 * 1024 = 4 MiB

inputs sharded along x and y -> $Inp[I_{X}, J_{Y}]$

weights sharded along x -> $W[J_{X}, K]$

Per-device shard sizes:
  - inputs
    - (2048 / 2) * (2048 / 2) * 2bytes
    - ~2MB
  - weights
    - (2048 / 2) * 1024 * 2bytes
    - ~2MB

The contracting dimension (J) is sharded in both arrays, but along different mesh axes (y for inputs, x for weights).
Resolving this mismatch requires collectives: an all-gather (AG) and/or an all-reduce (AR).

In [11]:
# inputs use P('x', 'y'): a 2x2 grid of (1024, 1024) tiles, one per device.
jax.debug.visualize_array_sharding(inputs)

In [12]:
# weights use P('x', None): two (1024, 1024) row halves, each replicated across y.
jax.debug.visualize_array_sharding(weights)

In [44]:
def jax_matmul(x: jax.Array, y: jax.Array) -> jax.Array:
    """Reference matmul written as a plain global-view program.

    Under jit, XLA's SPMD partitioner picks the collectives from the operand
    shardings.
    """
    product = jnp.matmul(x, y)
    return product

@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul_1(input_shard: jax.Array, w_shard: jax.Array) -> jax.Array:
    """Baseline schedule: all-gather both operands up front, then one local GEMM.

    input_shard is gathered over 'y' (its columns, the contracting dim) and
    w_shard over 'x' (its rows, the contracting dim). Each device then holds
    its full row stripe of inputs and the complete weights, so the local
    product is already the device's final row stripe and no reduction is needed.
    """
    gather_tiled = functools.partial(jax.lax.all_gather, tiled=True)
    with jax.named_scope('all_gather(s)'):
        lhs = gather_tiled(input_shard, 'y', axis=1)
        rhs = gather_tiled(w_shard, 'x', axis=0)
    with jax.named_scope('dot'):
        return lhs @ rhs

In [61]:
# numerics.compare(jax_matmul(inputs, weights), xla_matmul_1(inputs, weights), rtol=1e-2, atol=1e-2, region_grid=(2,2))
# Jit the shard_map'd function once, then time the compiled callable with
# sfp.utils.benchmark.
jitted = jax.jit(xla_matmul_1)
benchmark(jitted, inputs, weights)

BenchmarkResult (50 iters, 3 warmup)
  mean:        0.306 ms
  median:      0.302 ms
  stdev:       0.014 ms
  min:         0.289 ms
  max:         0.358 ms
  p95:         0.331 ms
  p99:         0.352 ms

In [59]:
# Run the global-view reference and inspect the output sharding that XLA chose.
out = jax_matmul(inputs, weights)
jax.debug.visualize_array_sharding(out)

In [None]:
# Jit the reference and run it once outside the profiler, so that compilation
# is not captured in the trace.
jax_matmul_compiled = jax.jit(jax_matmul)
jmc = jax_matmul_compiled(inputs, weights)
jmc.block_until_ready()

# Profile one warm execution; the trace is written to ./traces.
with jax.profiler.trace('./traces'):
    result = jax_matmul_compiled(inputs, weights)
    result.block_until_ready()

"""
    %all-reduce = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-reduce(bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %fusion), channel_id=2,
    replica_groups=[2,2]<=[4], use_global_device_ids=true, to_apply=%add.clone

    ^^ {1,0:T(8,128)(2,1)} ; column major -- probably RHS prepping for tranpose

    %fusion = bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done,
    bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %collective-permute-done), kind=kOutput, calls=%fused_computation

    %all-reduce = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-reduce(bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %fusion), channel_id=2,
    replica_groups=[2,2]<=[4], use_global_device_ids=true, to_apply=%add.clone
"""


# xm1 = jax.jit(xla_matmul_1)
# result = xm1(inputs, weights)
# result.block_until_ready()

# with jax.profiler.trace(TRACES_DIR):
#     result = xm1(inputs, weights)
#     result.block_until_ready()

'\n    %all-reduce = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-reduce(bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %fusion), channel_id=2,\n    replica_groups=[2,2]<=[4], use_global_device_ids=true, to_apply=%add.clone\n\n    ^^ {1,0:T(8,128)(2,1)} ; column major -- probably RHS prepping for tranpose\n\n    %fusion = bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done,\n    bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %collective-permute-done), kind=kOutput, calls=%fused_computation\n\n    %all-reduce = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-reduce(bf16[1024,1024]{1,0:T(8,128)(2,1)S(1)} %fusion), channel_id=2,\n    replica_groups=[2,2]<=[4], use_global_device_ids=true, to_apply=%add.clone\n'

In [16]:
@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul_2(input_shard: jax.Array, weight_shard: jax.Array) -> jax.Array:
    """
    A more efficient schedule than stacking two all-gathers at the start.

    Only the weights are all-gathered, over 'x'. Each device then multiplies its
    (m_local, k_local) input block by the matching k_local-row stripe of the
    weights. The row stripe is chosen by the device's 'y' index, because the
    input's contracting dim is sharded over 'y'. The resulting partial sums are
    all-reduced over 'y'.

    Slice sizes come from the shard shapes, not hard-coded 1024s, so the
    function works for any m, k, n that divide evenly across the mesh.
    """
    y_idx = jax.lax.axis_index('y')
    # All-gather the weights over x so that each device holds a full copy.
    w_full = jax.lax.all_gather(weight_shard, 'x', axis=0, tiled=True)
    # This device's inputs cover contracting-dim rows
    # [y_idx * k_local, (y_idx + 1) * k_local) of the weights.
    k_local = input_shard.shape[1]
    w_slice = jax.lax.dynamic_slice_in_dim(w_full, y_idx * k_local, k_local, axis=0)
    local_out = input_shard @ w_slice
    # All-reduce over the y ring to accumulate the partial results.
    out = jax.lax.psum(local_out, 'y')
    return out

In [17]:
# Default (tight) tolerances: both paths produce bf16 results end to end.
jnp.allclose(jax_matmul(inputs, weights), xla_matmul_2(inputs, weights))

Array(True, dtype=bool)

In [18]:
@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul_3(input_shard: jax.Array, weight_shard: jax.Array) -> jax.Array:
    """
    Same schedule as xla_matmul_2, but with higher-precision (fp32) numerics,
    to demonstrate the effect of accumulation order and precision.

    The local GEMM accumulates in fp32 and the all-reduce over 'y' also runs in
    fp32, so the result dtype is float32 (the bf16 reference returns bfloat16).
    Slice sizes are derived from the shard shapes rather than hard-coded.
    """
    y_idx = jax.lax.axis_index('y')
    w_full = jax.lax.all_gather(weight_shard, 'x', axis=0, tiled=True)
    # Weight rows matching this device's slice of the contracting dim.
    k_local = input_shard.shape[1]
    w_slice = jax.lax.dynamic_slice_in_dim(w_full, y_idx * k_local, k_local, axis=0)
    # preferred_element_type=f32 keeps the accumulator (and the output) in fp32.
    # Precision.HIGHEST matters mainly for f32 operands. For these bf16 operands
    # it is likely a no-op; TODO(review): confirm it costs nothing in a profile.
    local_out = jax.lax.dot_general(
        input_shard, w_slice,
        dimension_numbers=(((1,), (0,)), ((), ())),
        precision=jax.lax.Precision.HIGHEST,
        preferred_element_type=jnp.float32,
    )
    # The psum runs in fp32, so partial sums are not rounded to bf16 before the reduction.
    out = jax.lax.psum(local_out, 'y')
    return out

In [19]:
# Expected to print False even with loose tolerances. xla_matmul_3 accumulates
# and all-reduces in fp32. The bf16 reference rounds its output to bf16 and,
# per the HLO dump above, all-reduces bf16 partial sums. Likely cause:
# outputs where the partial sums cancel end up with large relative error in
# the reference. Use numerics.compare for a per-region breakdown.
jnp.allclose(xla_matmul_3(inputs, weights), jax_matmul(inputs, weights), rtol=1e-2, atol=1e-2)

Array(False, dtype=bool)

In [None]:
"""
Let's recall what we've learned so far --

When needed to perform an all gather on the reduction axis of our weights
to remove the sharding over X

Then, we compute local MatMuls (slicing out the appropriate data) between
the shard local inputs and the full weights
- Recall, we have the full weights after AG, so we need to slice out the
  appropriate chunks of W for our computation

These MatMuls are accumulators -> They finally need to be all reduced
over Y. The desired out sharding is ('x', None), so when we do an AR
over the Y axis, we are sharing the partial results
"""

"""
Let's sketch out the algorithm we care about --

We have 2 arrays distributed over 4 devices
  - 1/4 of inputs on each device
  - 1/2 of weights of each device

We want to efficiently compute this distributed matmul over the devices

We know that the contracting dims are sharded differently

So there will need to be some comms to unshard so that we have the
whole array in the right place. HOWEVER, we may also be able to get
away with ppermute to simple pass _results_ after compute is finished

Here are the kernels we may want to try
  - Simple matmul with lax collectives inserted in the right spots
  - MatMul with handrolled collectives (still AG to start, then AR)
  - ppermute
    - Issue async DMA
    - Run local compute; stash in accumulator
    - 
    - How much latency can we hide here?
      - If the DMAs are fast/slow?
      - How do we _reason_ about these tradeoffs

Start with 2x2 case; don't worry too much about abstracting things out
  - Then extend to larger configurations + abstractions
  - How do we think about the work we're doing?
  - Where are the opportunities to show different edge cases?
  - Where do our assumptions break down?
  - emit_pipeline
  - kernel schedule?
  - What happens when we run this on Trillium?
    - What changes?
"""

In [74]:
"""
Here's what happens:
- We have 2 arrays in HBM
- Need to all_gather weights over x, now each device has fully copy of x
- Then do local compute
- All Reduce the local compute to get the correct results
"""

#NOTE: THIS IS THE CASE WHERE WE SIMPLY REPLACE jax.dot/jax.matmul/x @ y in xla_matmul3
def simple_matmul(x_ref, y_ref, o_ref, scratch_ref, *, n_steps):
  # Zero scratch buffer
  @pl.when(pl.program_id(2) == 0)
  def _init_scratch():
    scratch_ref[...] = jnp.zeros_like(scratch_ref)

  # Compute dot
  scratch_ref[...] += jnp.dot(
    x_ref[...],
    y_ref[...],
    preferred_element_type=jnp.float32
  )

  # Flush to HBM
  @pl.when(pl.program_id(2) == n_steps - 1)
  def _flush_scratch():
    o_ref[...] = scratch_ref[...].astype(o_ref.dtype)


def make_matmul(
  x: jax.Array,
  y: jax.Array,
  *,
  bm: int = 128,
  bk: int = 128,
  bn: int = 128,
  out_dtype=jnp.bfloat16,
  interpret: bool = False,
):
  """Tiled Pallas matmul x @ y using simple_matmul with an f32 VMEM accumulator.

  Args:
    x: (m, k) left operand.
    y: (k, n) right operand.
    bm, bk, bn: tile sizes. The grid is (m // bm, n // bn, k // bk), with the
      contracting (k) axis innermost.
    out_dtype: dtype of the result. Defaults to bf16; accumulation is always f32.
    interpret: run the Pallas interpreter instead of compiling (e.g. on CPU).

  Returns:
    (m, n) array of out_dtype.

  Raises:
    ValueError: if the contracting dims of x and y differ, or a dimension is
      not a multiple of its tile size. The grid uses floor division, so a
      remainder would otherwise be silently left uncomputed.
  """
  m, k = x.shape
  k_y, n = y.shape
  if k_y != k:
    raise ValueError(f"contracting dims differ: x is {x.shape}, y is {y.shape}")
  for dim_name, dim, tile in (("m", m, bm), ("k", k, bk), ("n", n, bn)):
    if dim % tile:
      raise ValueError(f"{dim_name}={dim} is not a multiple of its tile size {tile}")

  n_k_steps = k // bk
  grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    grid=(m // bm, n // bn, n_k_steps),
    in_specs=[
      pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
      pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
    ],
    out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
    # f32 accumulator; it lives across the k steps of one (i, j) output tile.
    scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
  )

  return pl.pallas_call(
    functools.partial(simple_matmul, n_steps=n_k_steps),
    grid_spec=grid_spec,
    # Accumulation is f32 in VMEM; only the final store is cast to out_dtype.
    out_shape=jax.ShapeDtypeStruct((m, n), dtype=out_dtype),
    interpret=interpret,
  )(x, y)


def distributed_gemm_kernel1(inputs, weights):
  """Per-device body (run under shard_map): XLA AG over x, Pallas GEMM, XLA AR over y.

  inputs:  this device's (m_local, k_local) block of inputs[I_x, J_y].
  weights: this device's block of weights[J_x, :].

  Slice and tile sizes are derived from the shard shapes instead of being
  hard-coded to 1024.
  """
  y_idx = jax.lax.axis_index('y')
  # AG: every device gets the full weights.
  w_full = jax.lax.all_gather(weights, 'x', axis=0, tiled=True)
  # Weight rows matching this device's slice of the contracting dim.
  k_local = inputs.shape[1]
  w_slice = jax.lax.dynamic_slice_in_dim(w_full, y_idx * k_local, k_local, axis=0)
  # One tile per shard (grid (1, 1, 1)), as with the previous hard-coded
  # 1024 tiles. NOTE(review): larger shards may need smaller tiles to fit VMEM.
  m_local, n = inputs.shape[0], w_slice.shape[1]
  local_out = make_matmul(inputs, w_slice, bm=m_local, bk=k_local, bn=n)
  # AR over y: sum the partial products of the contracting-dim shards.
  return jax.lax.psum(local_out, 'y')

In [75]:
# jit(shard_map(...)): every device runs distributed_gemm_kernel1 on its own
# shards; the row stripes of the result are laid out along x.
dgk1 = jax.jit(
    jax.shard_map(
        distributed_gemm_kernel1,
        mesh=mesh,
        in_specs=(P('x', 'y'), P('x', None)),
        out_specs=P('x', None),
        check_vma=False,
    )
)

In [76]:
# jnp.allclose(dgk1(inputs, weights), jax_matmul(inputs, weights), rtol=1e-2, atol=1e-2)
# Time the jitted XLA-collectives + Pallas-GEMM pipeline; numerics are checked separately.
benchmark(dgk1, inputs, weights)

BenchmarkResult (50 iters, 3 warmup)
  mean:        0.295 ms
  median:      0.294 ms
  stdev:       0.007 ms
  min:         0.278 ms
  max:         0.317 ms
  p95:         0.307 ms
  p99:         0.313 ms

In [78]:
# Compare the Pallas-GEMM pipeline against the XLA reference, with a 2x2
# per-region error breakdown.
ref = jax_matmul(inputs, weights)
test = dgk1(inputs, weights)

numerics.compare(ref, test, atol=1e-2, rtol=1e-2, region_grid=(2,2))

NumericsResult(PASS)
  shape:     (2048, 1024)
  max_diff:  0.000000
  mean_diff: 0.000000
  median:    0.000000
  % > 0.1: 0.00%
  worst at (0, 0): ref=-71.0000, test=-71.0000
  regions (2x2):
    [0,0]: mean=0.0000, %bad=0.0%
    [0,1]: mean=0.0000, %bad=0.0%
    [1,0]: mean=0.0000, %bad=0.0%
    [1,1]: mean=0.0000, %bad=0.0%

In [79]:
def all_gather_kernel_1D(
  input_ref, output_ref,
  local_send_sem, send_sem, recv_sem, 
):
  """
  Unidirectional ring all-gather along the mesh 'x' axis (Pallas TPU kernel).

  Grid step 0 copies this device's shard into its own slot of output_ref.
  Every grid step then pushes one slot of output_ref into the same slot of the
  right-hand x-neighbour's output_ref. For an x-ring of size N this needs
  grid=(N - 1,); grid_spec_ag1d uses grid=(1,), which suffices for 2 devices.

  input_ref: shard local data, shape (shard_height, shard_width), left in HBM (pl.ANY)
  output_ref: gathered result in HBM (pl.ANY); rows
    [i * shard_height, (i + 1) * shard_height) hold device i's shard
  local_send_sem: allocates a semaphore for the local HBM copy
  send_sem: semaphore for the RDMA push
  recv_sem: signalled on this device when the left neighbour's push lands
  """
  # TODO: Barrier

  pid = pl.program_id(0)
  

  # Ref shapes are static at trace time.
  shard_height = input_ref.shape[0]
  # NOTE(review): shard_width is unused.
  shard_width = input_ref.shape[1]

  # Get neighbors
  x_ring = jax.lax.axis_size('x')
  this_device_x = jax.lax.axis_index('x')
  this_device_y = jax.lax.axis_index('y')
  right_device_x = jax.lax.rem(this_device_x + 1, x_ring)

  # Recall: This is the _destination_ copy slot (and also the source slot).
  # Slot for step p is (this_device_x - p) mod N. At p == 0 that is our own
  # shard, written by the local copy below. At p > 0 it is the slot the left
  # neighbour pushed to us at step p - 1: it sent
  # ((this_device_x - 1) - (p - 1)) = this_device_x - p, and right_dma.wait()
  # at that step waited for the push to land. So the indexing is a correct
  # ring schedule for grid=(N - 1,), not just for the 2-device case.
  copy_slot_xright = this_device_x - pid
  copy_slot_xright = jax.lax.rem(copy_slot_xright + x_ring, x_ring)

  # We're just copying within our HBM to a bigger HBM memory
  @pl.when(pl.program_id(0) == 0)
  def _copy_local_to_local():
    local_hbm_copy = pltpu.make_async_copy(
      src_ref=input_ref,
      dst_ref=output_ref.at[pl.ds(this_device_x * shard_height, shard_height), :],
      sem=local_send_sem
    )

    local_hbm_copy.start()
    local_hbm_copy.wait()

  # Push the current slot to the right neighbour's output_ref (same slot offset).
  right_dma = pltpu.make_async_remote_copy(
    src_ref=output_ref.at[pl.ds(copy_slot_xright * shard_height, shard_height), :],
    dst_ref=output_ref.at[pl.ds(copy_slot_xright * shard_height, shard_height), :],
    send_sem=send_sem,
    recv_sem=recv_sem,
    device_id=(right_device_x, this_device_y),
    device_id_type=pltpu.DeviceIdType.MESH,
  )

  # wait() covers both our outgoing send and the incoming push from the left
  # neighbour (recv_sem on this device).
  right_dma.start()
  right_dma.wait()
  

# Grid spec for all_gather_kernel_1D. It has a single grid step, which is
# enough for a 2-device x-ring; an N-device ring would need grid=(N - 1,).
# Both operands stay in HBM and are moved with explicit DMAs.
grid_spec_ag1d = pltpu.PrefetchScalarGridSpec(
  num_scalar_prefetch=0,
  grid=(1,),
  in_specs=[
    # Our input reference is just our big tensor in HBM
    pl.BlockSpec(memory_space=pl.ANY)
  ],
  # Our output reference will be _another_ big tensor in HBM
  out_specs=pl.BlockSpec(memory_space=pl.ANY),
  # Scratch refs are passed positionally after the in/out refs, so this order
  # must match the kernel's trailing parameters.
  scratch_shapes=(
    [pltpu.SemaphoreType.DMA] * 2 # local_copy_op, send_sem
    + [pltpu.SemaphoreType.DMA] * 1 # These are our recv_sems. For 2x2, we only need 1 of them
  )
)

# Per-device output shape: the full gathered weights.
# NOTE(review): hard-coded from the global `weights` array; any other operand
# needs its own out_shape.
out_shape_ag1d=jax.ShapeDtypeStruct((weights.shape), dtype=jnp.bfloat16)

def make_ag(x, interpret: None | bool = None):
  """Launch the 1-D (x-axis ring) Pallas all-gather on this device's shard.

  Must be called inside shard_map over a mesh with axes 'x' and 'y', because
  the kernel reads both axis indices.

  Args:
    x: this device's shard of the weights.
    interpret: None (the default) auto-detects: compile on TPU, use the Pallas
      interpreter elsewhere. An explicit True/False is honoured as given;
      previously an explicit False was treated like None.
  """
  if interpret is None:
    # Mosaic TPU kernels only compile on TPU; elsewhere fall back to the
    # interpreter (GPU is not handled yet).
    interpret = jax.devices()[0].platform != 'tpu'

  return pl.pallas_call(
    all_gather_kernel_1D,
    grid_spec=grid_spec_ag1d,
    out_shape=out_shape_ag1d,
    interpret=interpret,
  )(x)

In [90]:
# Reference all-gather using XLA collectives.
# NOTE(review): every device returns the full gathered weights, so
# out_specs=P('x', None) stacks two copies into a (4096, 1024) global array.
# That is fine for this equality check, since both sides do the same.
xla_ag = jax.jit(
    jax.shard_map(
        lambda x: jax.lax.all_gather(x, 'x', tiled=True),
        mesh=mesh, in_specs=P('x', None), out_specs=P('x', None))
)(weights)

# The same all-gather using the 1-D Pallas ring kernel.
ag1d = jax.jit(
  jax.shard_map(
      make_ag,
      mesh=mesh,
      in_specs=P('x', None),
      out_specs=P('x', None),
      check_vma=False
  )
)(weights)

# Validate the Pallas kernel against XLA. This cell previously computed both
# results but never compared them. An all-gather only moves data, so the
# results should match bit for bit.
jnp.array_equal(xla_ag, ag1d)

In [None]:
"""
NOTE: When I was trying to directly use 
`output_ref[...] = local_hbm_ref[...] + output_ref[...]` it was causing an error.
> ValueError: Loads are only allowed on VMEM and SMEM references. ANY memory space can only be accessed using async_copy.

Reason:
output_ref is marked at pl.ANY. Only VMEM/SMEM are addressable by compute units (ref[...] compiles to loads/stores)
Off-chip memory (HBM; pl.ANY) accessible via DMAs, not loads/stores
"""

def all_reduce_kernel_1D(
    local_hbm_ref, output_ref,
    send_sem, recv_sem, copy_sem,
    local_scratch, recv_scratch,
):
    """
    Sum all-reduce of the GEMM partial results over the 'y' axis.

    Context: the weights were all-gathered into every device's HBM and a GEMM
    was run on device-local chunks (inputs[I_x, J_y] @ weights[J_y, :]). Those
    row-stripe partial sums must be added across 'y', so that the result
    stays sharded as P('x', None).

    Schedule (a single exchange step; only correct for a y-ring of size 2,
    larger rings need a reduce-scatter + all-gather):
      1. Push our whole partial (HBM) into the right y-neighbour's output_ref (HBM).
      2. Once the neighbour's partial has landed in our own output_ref, stage it
         into recv_scratch (VMEM). Previously recv_scratch was never written, so
         the add read uninitialised VMEM.
      3. local + received in VMEM, then write the sum back over output_ref.

    Parameter order matches grid_spec_ar1d's scratch_shapes (3 DMA semaphores,
    then the local and the recv VMEM buffers). The previous order handed a
    semaphore to local_scratch and a VMEM buffer to copy_sem.

    Refs:
      local_hbm_ref: this device's partial result (HBM / pl.ANY).
      output_ref: HBM (pl.ANY). First the landing zone for the neighbour's
        partial, then the reduced result.
      send_sem / recv_sem: DMA semaphores for the remote copy. recv_sem is
        signalled on the receiving device.
      copy_sem: DMA semaphore for the local HBM <-> VMEM copies.
      local_scratch / recv_scratch: VMEM buffers with the shard's shape.

    This will be _SLOW_ for now because of the extra HBM round trips.
    TODO: barrier before the remote write, so a peer cannot write into
    output_ref before it has entered this kernel.
    """
    y_ring = jax.lax.axis_size('y')
    this_device_x = jax.lax.axis_index('x')
    this_device_y = jax.lax.axis_index('y')
    right_device_y = jax.lax.rem(this_device_y + 1, y_ring)

    # Start loading our own partial into VMEM while the exchange is in flight.
    local_copy = pltpu.make_async_copy(
        src_ref=local_hbm_ref,
        dst_ref=local_scratch,
        sem=copy_sem
    )
    local_copy.start()

    # Push our partial into the neighbour's output_ref (HBM -> remote HBM).
    right_dma = pltpu.make_async_remote_copy(
        src_ref=local_hbm_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=(this_device_x, right_device_y),
        device_id_type=pltpu.DeviceIdType.MESH
    )
    right_dma.start()
    local_copy.wait()
    # wait() = our send has finished AND the neighbour's partial has fully
    # landed in our output_ref.
    right_dma.wait()

    # HBM refs (pl.ANY) cannot be loaded directly; stage the received partial into VMEM.
    recv_copy = pltpu.make_async_copy(
        src_ref=output_ref,
        dst_ref=recv_scratch,
        sem=copy_sem
    )
    recv_copy.start()
    recv_copy.wait()

    # Add in VMEM, write back to HBM
    local_scratch[...] = local_scratch[...] + recv_scratch[...]

    out_copy = pltpu.make_async_copy(
        src_ref=local_scratch,
        dst_ref=output_ref,
        sem=copy_sem
    )
    out_copy.start()
    out_copy.wait()


# Grid spec for all_reduce_kernel_1D. It has a single step; the operands stay
# in HBM (pl.ANY) and are moved with explicit DMAs.
grid_spec_ar1d=pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    grid=(1,),
    in_specs=[
        pl.BlockSpec(memory_space=pl.ANY),
    ],
    out_specs=pl.BlockSpec(memory_space=pl.ANY),
    # NOTE(review): Pallas passes scratch refs positionally after the in/out
    # refs, so this order (3 DMA semaphores, local scratch, recv scratch) must
    # match the kernel's trailing parameters.
    # The VMEM buffers are hard-coded to a (1024, 1024) bf16 shard and must
    # match the partial result passed to make_ar.
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 3 # send, recv, copy
        + [pltpu.VMEM((1024, 1024), jnp.bfloat16)] # local scratch
        + [pltpu.VMEM((1024, 1024), jnp.bfloat16)] # recv scratch
    )
)

# NOTE(review): unused; make_ar derives its out_shape from its argument.
out_shape_ar1d = jax.ShapeDtypeStruct((weights.shape), dtype=jnp.bfloat16)

"""
Notice how you can also achieve the automatic HBM pipelining with blockspecs
as you might with a GEMM

def all_reduce_kernel_1D(
    local_ref,      # VMEM (auto-copied from HBM by Pallas)
    output_ref,     # VMEM (auto-copied to HBM by Pallas)
    recv_scratch,   # VMEM scratch
    send_sem, recv_sem
):
    y_ring = jax.lax.axis_size('y')
    this_device_x = jax.lax.axis_index('x')
    this_device_y = jax.lax.axis_index('y')
    right_device_y = jax.lax.rem(this_device_y + 1, y_ring)

    right_dma = pltpu.make_async_remote_copy(
        src_ref=local_ref,
        dst_ref=recv_scratch,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=(this_device_x, right_device_y),
        device_id_type=pltpu.DeviceIdType.MESH
    )
    right_dma.start()
    right_dma.wait()

    output_ref[...] = local_ref[...] + recv_scratch[...]
    """


def make_ar(input_array, interpret: None | bool = None):
    """Launch all_reduce_kernel_1D on this device's partial result.

    Must be called inside shard_map, because the kernel reads the 'x' and 'y'
    axis indices. The output has the input's shape and dtype.

    Args:
      input_array: this device's partial result. Its shape must match the
        VMEM scratch shapes in grid_spec_ar1d.
      interpret: None (the default) auto-detects: compile on TPU, use the Pallas
        interpreter elsewhere. An explicit True/False is honoured as given.
        Previously this argument was accepted but ignored.
    """
    # Alternative grid spec for the BlockSpec (auto HBM<->VMEM) variant of the kernel:
    # ar_grid_spec = pltpu.PrefetchScalarGridSpec(
    # num_scalar_prefetch=0,
    # grid=(1,),
    # in_specs=[pl.BlockSpec((1024, 1024), lambda i: (0, 0))],
    # out_specs=pl.BlockSpec((1024, 1024), lambda i: (0, 0)),
    # scratch_shapes=(
    #     [pltpu.VMEM((1024, 1024), jnp.bfloat16)]  # recv_scratch
    #     + [pltpu.SemaphoreType.DMA] * 2
    #     )
    # )

    if interpret is None:
        # Mosaic TPU kernels only compile on TPU.
        interpret = jax.devices()[0].platform != 'tpu'

    out_shape = jax.ShapeDtypeStruct(input_array.shape, input_array.dtype)

    return pl.pallas_call(
        all_reduce_kernel_1D,
        grid_spec=grid_spec_ar1d,
        out_shape=out_shape,
        interpret=interpret,
    )(input_array)

In [95]:
def slow_ag_gemm_ar_kernel(inputs, weights):
    """Per-device body (run under shard_map): Pallas AG -> slice -> Pallas GEMM -> Pallas AR.

    Every stage round-trips through HBM, hence "slow". `make_ag` is looked up
    when this is traced, so it uses whichever make_ag definition is current in
    the notebook. Slice and tile sizes come from the shard shapes.
    """
    y_idx = jax.lax.axis_index('y')
    w_full = make_ag(weights)
    # Weight rows matching this device's slice of the contracting dim.
    k_local = inputs.shape[1]
    w_slice = jax.lax.dynamic_slice_in_dim(w_full, y_idx * k_local, k_local, axis=0)
    # One tile per shard (grid (1, 1, 1)), as with the previous hard-coded 1024s.
    m_local, n = inputs.shape[0], w_slice.shape[1]
    local_out = make_matmul(inputs, w_slice, bm=m_local, bk=k_local, bn=n)
    return make_ar(local_out)

In [96]:
# jit(shard_map(...)) over the all-Pallas pipeline: AG kernel -> GEMM kernel -> AR kernel.
sagak = jax.jit(
    jax.shard_map(
        slow_ag_gemm_ar_kernel,
        mesh=mesh,
        in_specs=(P('x', 'y'), P('x', None)),
        out_specs=P('x', None),
        check_vma=False
    )
)

# Run once; this also compiles it.
sagak(inputs, weights)

Array([[-71, -42.5, -7.875, ..., 22.5, 85, 40],
       [8.875, -40, 21, ..., -27.75, 74.5, -51.5],
       [-36.75, -18.625, 90, ..., -70, -34, 5.625],
       ...,
       [-1.75, -9.625, 20.75, ..., 47.25, -47.5, 1.90625],
       [-7.21875, 30.375, -26, ..., 11, -95, -66.5],
       [-11, -24.375, 48, ..., -72.5, -17.625, 18.875]], dtype=bfloat16)

In [None]:
# NOTE: These benchmarks look dominated by overhead
# Need to rely on profile for accurate information
# (host-side timings from sfp.utils.benchmark; compare with device time in the profile)
result = benchmark(sagak, inputs, weights)

Wall time: 16.6 ms
BenchmarkResult (50 iters, 3 warmup)
  mean:        0.289 ms
  median:      0.289 ms
  stdev:       0.008 ms
  min:         0.274 ms
  max:         0.308 ms
  p95:         0.302 ms
  p99:         0.306 ms


In [98]:
# Jit the reference too, so it is timed like the jitted kernels in the other
# benchmark cells rather than through eager op-by-op dispatch.
benchmark(jax.jit(jax_matmul), inputs, weights)

BenchmarkResult (50 iters, 3 warmup)
  mean:        0.287 ms
  median:      0.286 ms
  stdev:       0.009 ms
  min:         0.273 ms
  max:         0.325 ms
  p95:         0.300 ms
  p99:         0.320 ms

In [None]:
# AOT-compile with an extra XLA option, then run once outside the profiler so
# compilation is not traced.
# TODO(review): confirm the TPU compiler recognises 'xla_enable_transpose_trace'.
sagak_compiled = sagak.lower(inputs, weights).compile({'xla_enable_transpose_trace': True})
result = sagak_compiled(inputs, weights)
result.block_until_ready()

# Profile one warm execution; the trace is written to ./traces.
with jax.profiler.trace('./traces'):
    result = sagak_compiled(inputs, weights)
    result.block_until_ready()

"""
113 us vs 208 us, slow but workable
"""

In [None]:
"""
Basically --
The way we organize the grid for communications is a decision
  - Rings, ranges, etc.
  - This trades off bandwidth/latency


# Ring
grid = (ring_size - 1,)

# Recursive doubling
grid = (log2(ring_size),)

# Direct
grid = (1,)  # but more complex RDMA pattern
"""

In [None]:
"""
ValueError: The context mesh AbstractMesh('x': 2, 'y': 2, axis_types=(Manual, Manual),
  device_kind=TPU v5 lite, num_cores=1) should match the mesh passed to shard_map Mesh('x': 2,
   'y': 2, axis_types=(Auto, Auto))
"""

@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False
)
def xla_matmul_4(input: jax.Array, weight: jax.Array) -> jax.Array:
    """Pallas all-gather + Pallas GEMM + XLA all-reduce, per device.

    Fixes versus the first draft:
      - The GEMM used the closed-over *global* `inputs` array instead of this
        device's `input` shard.
      - The AG called an undefined `mag1`. The raw `make_ag` kernel launcher is
        called here instead. Calling a jitted shard_map wrapper would nest a
        second shard_map over the same mesh inside this one, which raises the
        "context mesh ... should match the mesh passed to shard_map" ValueError.
      - Slice sizes come from the shard shapes instead of hard-coded 1024s.
    """
    y_idx = jax.lax.axis_index('y')
    w_full = make_ag(weight)
    # Using the y-ring axis to determine which row stripe of the weights to use locally.
    k_local = input.shape[1]
    w_slice = jax.lax.dynamic_slice_in_dim(w_full, y_idx * k_local, k_local, axis=0)
    local_out = make_matmul(input, w_slice)
    # All-reduce over the y ring to accumulate the partial results.
    out = jax.lax.psum(local_out, 'y')
    return out

In [None]:
"""
Here's something fun --

NotImplementedError: Meshes with more than 1 named
dimension not implemented in dma_start_p

Pallas's remote DMA primitives currently only support 1D meshes

https://github.com/jax-ml/jax/blob/main/jax/_src/pallas/mosaic/primitives.py
"""

In [None]:
# EXTEND THE KERNEL TO BE MORE GENERAL THAN 2x2 ; SCALE UP TO 4x4 GRID

In [11]:
"""
On each local device, we have 2 arrays sitting there -> inputs + weights
On kernel start:
  - Barrier sync to get everyone on the same stage
    - Might could relax this constraint?
  - Issue remote DMAs along the ... ring to all gather the weights
  - Once received
    - Compute local dots
    - Accumulate
  - Finally
    - All Reduce over the y ring to conform shmap shape


Additional notes:
  RECALL: n_hopes = grid (ring_len) - 1 -> I think this is just using one ICI link though
  What happens at n = 1? 2? 3?
  n=1:
    - You send right, receive left (this kernel is only looking at 1D case rn)
    - At (0,0): You send to (1,0) the chunk starting at 0 * row_height
    - You receive chunks from (1,0) and (3, 0)
    - You have [x, x, O, x] ; need idx 2 * row_height
    - r_neighbor needs -> [x, x, x, O]
  n=2:
    - This would be the epilogue already w/ 2 ICI links**
    - You send r_neighbor+1 chunk ; could probably push this up to initial mappings
"""

def all_gather_kernel_bidi(
  input_ref, output_ref,
  local_send_sem, send_sem_right, send_sem_left,
  recv_sem_right, recv_sem_left
):
  """
  Bidirectional ring all-gather along the mesh 'x' axis (Pallas TPU kernel).

  Grid step 0 copies this device's shard into its own slot of output_ref.
  Each step then pushes slot (x - pid) mod N to the right neighbour and slot
  (x + pid) mod N to the left neighbour, into the same slots of their
  output_ref. With both directions in use, an N-ring needs about
  ceil((N - 1) / 2) steps; the grid spec below uses grid=(1,).

  input_ref: shard local data, shape (shard_height, shard_width)
  output_ref: out shard; rows [i * shard_height, (i + 1) * shard_height) hold device i's shard
  local_send_sem: allocates a semaphore for the local HBM copy
  send_sem_right / send_sem_left: semaphores for the RDMA pushes, one per direction
  recv_sem_right / recv_sem_left: signalled on this device when a neighbour's
    right-going / left-going push lands here

  NOTE(review): on a 2-device x-ring with grid=(1,), the left and right
  neighbours are the same device and both DMAs send our own slot. The whole
  shard therefore crosses twice (redundant but correct). Splitting the shard
  in half, one half per direction, would use both links without duplication.
  """
  # TODO: Barrier

  pid = pl.program_id(0)
  
  # This should be baked into compiled artifact? Or is it runtime?
  # (Ref shapes are static at trace time, so these are compile-time constants.)
  # There has to be a prettier way to do this?
  shard_height = input_ref.shape[0]
  # NOTE(review): shard_width is unused.
  shard_width = input_ref.shape[1]

  this_device_x = jax.lax.axis_index('x')
  this_device_y = jax.lax.axis_index('y')
  x_ring = jax.lax.axis_size('x')
  right_device_x = jax.lax.rem(this_device_x + 1, x_ring)
  left_device_x = jax.lax.rem(this_device_x - 1 + x_ring, x_ring)

  # y_ring = jax.lax.axis_size('y')
  # right_device_y = jax.lax.rem(this_device_y + 1, y_ring)
  # left_device_y = jax.lax.rem(this_device_y - 1 + y_ring, y_ring)


  # This accounts for the offset when data is being sent both ways.
  # Right-going slot at step p: (x - p) mod N, received at step p - 1 from the
  # left neighbour. Left-going slot: (x + p) mod N, received at step p - 1 from
  # the right neighbour. At p == 0 both are our own shard.
  copy_slot_right = this_device_x - pid
  copy_slot_right = jax.lax.rem(copy_slot_right + x_ring, x_ring)
  copy_slot_left = jax.lax.rem(this_device_x + pid, x_ring)


  # NOTE(review): local_hbm_copy is an unused alias of local_copy.
  local_copy = local_hbm_copy = pltpu.make_async_copy(
    src_ref=input_ref,
    dst_ref=output_ref.at[pl.ds(this_device_x * shard_height, shard_height), :],
    sem=local_send_sem
)

  # PERFORM INITIAL ASYNC COPY FROM OUR HBM TO OUR HBM
  # We're just copying within our HBM to a bigger HBM memory
  # XLA liveness should handle malloc/free the _INPUT_ tensor once the AG completes
  @pl.when(pl.program_id(0) == 0)
  def _copy_local_to_local():
    # We can defer the wait until literally the very end of the kernel
    # TODO: We can defer this copy and make it free**
    # Need to add 2 more DMAs that send from input_ref -> remote output_ref at
    # the beginning, rather than output_ref -> output_ref
    # This is needlessly serial right now
    local_copy.start()
    local_copy.wait()

  right_dma = pltpu.make_async_remote_copy(
    # Next kernel iter depends on completion of left/right DMAs
    src_ref=output_ref.at[pl.ds(copy_slot_right * shard_height, shard_height), :],
    dst_ref=output_ref.at[pl.ds(copy_slot_right * shard_height, shard_height), :],
    send_sem=send_sem_right,
    # Imagine this as an array of semaphores-> [Dev1, Dev2, Dev3, Dev4, ..., DevN]
    # Signals the semaphore on the _destination_ device to signal
    recv_sem=recv_sem_right,
    device_id=(right_device_x, this_device_y),
    device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_dma = pltpu.make_async_remote_copy(
    src_ref=output_ref.at[pl.ds(copy_slot_left * shard_height, shard_height), :],
    dst_ref=output_ref.at[pl.ds(copy_slot_left * shard_height, shard_height), :],
    send_sem=send_sem_left,
    recv_sem=recv_sem_left,
    device_id=(left_device_x, this_device_y),
    device_id_type=pltpu.DeviceIdType.MESH
  )

  # Each wait() covers our outgoing send and the matching incoming push.
  right_dma.start()
  left_dma.start()
  right_dma.wait()
  left_dma.wait()
  

# Grid spec for all_gather_kernel_bidi.
# NOTE(review): make_ag looks this global up at call time, so any later
# rebinding of the generic name `grid_spec` silently changes what make_ag launches.
grid_spec = pltpu.PrefetchScalarGridSpec(
  num_scalar_prefetch=0,
  # If you're using both ICI links:
  # grid=ceil((num_devices-1) / 2,) -> You're halving the num iters
  # grid=((ring_size-1) / 2)
  grid=(1,),
  in_specs=[
    # Our input reference is just our big tensor in HBM
    pl.BlockSpec(memory_space=pl.ANY)
  ],
  # Our output reference will be _another_ big tensor in HBM
  out_specs=pl.BlockSpec(memory_space=pl.ANY),
  # This will be an error if you need more semaphores for more neighbors
  # Scratch refs are passed positionally after the in/out refs, matching the
  # kernel's (local_send_sem, send_sem_right, send_sem_left, recv_sem_right, recv_sem_left).
  scratch_shapes=(
    [pltpu.SemaphoreType.DMA] * 3 # local copy + one send sem per direction
    + [pltpu.SemaphoreType.DMA] * 2 # one recv sem per direction (both are used, even on 2x2)
  )
)

# Per-device output shape: the full gathered weights (hard-coded from the global `weights`).
out_shape=jax.ShapeDtypeStruct((weights.shape), dtype=jnp.bfloat16)


# TODO: should we parameterize this _here_; aka pass in the shard shape data here?
def make_ag(x, interpret: None | bool = None):
  """Launch the bidirectional x-ring Pallas all-gather on this device's shard.

  NOTE: this redefines (shadows) the earlier 1-D make_ag. Functions that call
  make_ag pick up this version the next time they are traced.

  Must be called inside shard_map over a mesh with axes 'x' and 'y'.

  Args:
    x: this device's shard of the weights.
    interpret: None (the default) auto-detects: compile on TPU, use the Pallas
      interpreter elsewhere. An explicit True/False is honoured as given;
      previously an explicit False was treated like None.
  """
  if interpret is None:
    # Mosaic TPU kernels only compile on TPU; elsewhere fall back to the
    # interpreter (GPU is not handled yet).
    interpret = jax.devices()[0].platform != 'tpu'

  return pl.pallas_call(
    all_gather_kernel_bidi,
    grid_spec=grid_spec,
    out_shape=out_shape,
    interpret=interpret,
  )(x)

In [13]:
# Reference all-gather using XLA collectives.
# NOTE(review): every device returns the full gathered weights, so
# out_specs=P('x', None) stacks two copies into a (4096, 1024) global array.
# That is fine for this comparison, since both sides do the same.
xla_ag = jax.jit(
    jax.shard_map(
        lambda x: jax.lax.all_gather(x, 'x', tiled=True),
        mesh=mesh, in_specs=P('x', None), out_specs=P('x', None))
)(weights)

# The same all-gather using the bidirectional Pallas kernel.
agbd = jax.jit(
  jax.shard_map(
      make_ag,
      mesh=mesh,
      in_specs=P('x', None),
      out_specs=P('x', None),
      check_vma=False
  )
)(weights)

jnp.allclose(xla_ag, agbd)

Array(True, dtype=bool)

In [None]:
# NOTE: This is where we introduce semaphores as a way to mark data dependencies,
# i.e. results from one computation _depend_ on previous results.
# A DMA's recv semaphore lives on the device where the matching wait() is called.

def all_reduce_bidi(
    input_ref, output_ref, hbm_scratch,
    local_send_sem, send_sem_right, send_sem_left,
    recv_sem_right, recv_sem_left,
):
  """Sum all-reduce over the 'y' axis, splitting traffic across both ring directions.

  Written for a y-ring of size 2 (the 2x2 mesh). There the left and right
  neighbours are the same device, so sending the top row-half "right" and the
  bottom row-half "left" delivers our whole partial to the peer as two DMAs,
  one per ICI direction, instead of one. Larger rings need a reduce-scatter +
  all-gather schedule instead.

  Refs (all HBM / pl.ANY):
    input_ref:   this device's partial result, shape (h, w).
    output_ref:  the reduced result, shape (h, w).
    hbm_scratch: landing zone for the peer's partial, shape (h, w). It is a
                 second kernel *output* (discarded by make_ar_bidi) so that it
                 is allocated in HBM; this is why it comes before the semaphores.
  Semaphores (DMA):
    local_send_sem: local HBM <-> VMEM copies.
    send_sem_right/left, recv_sem_right/left: one pair per ring direction.
  """
  shard_height, shard_width = input_ref.shape
  # Split along rows: each half is a contiguous block in HBM.
  top = shard_height // 2
  bottom = shard_height - top

  y_ring = jax.lax.axis_size('y')
  this_device_x = jax.lax.axis_index('x')
  this_device_y = jax.lax.axis_index('y')
  right_device_y = jax.lax.rem(this_device_y + 1, y_ring)
  left_device_y = jax.lax.rem(this_device_y - 1 + y_ring, y_ring)

  # Design notes (from the first draft):
  # - Send directly to remote VMEM to overlap comms with local read latency.
  #   This needs careful synchronization (probably capacity semaphores) unless
  #   VMEM can hold all of the data, and even then we are waiting to accumulate.
  # - This can still be pipelined aggressively with bidi: work from one slot
  #   while receiving into another.
  # TODO: barrier before the first remote write, so a peer cannot write into
  # its hbm_scratch before it has entered this kernel.

  right_dma = pltpu.make_async_remote_copy(
    src_ref=input_ref.at[pl.ds(0, top), :],
    dst_ref=hbm_scratch.at[pl.ds(0, top), :],
    send_sem=send_sem_right,
    recv_sem=recv_sem_right,
    device_id=(this_device_x, right_device_y),
    device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_dma = pltpu.make_async_remote_copy(
    src_ref=input_ref.at[pl.ds(top, bottom), :],
    dst_ref=hbm_scratch.at[pl.ds(top, bottom), :],
    send_sem=send_sem_left,
    recv_sem=recv_sem_left,
    device_id=(this_device_x, left_device_y),
    device_id_type=pltpu.DeviceIdType.MESH,
  )

  right_dma.start()
  left_dma.start()
  # Each wait() covers our send plus the peer's matching push into our hbm_scratch.
  right_dma.wait()
  left_dma.wait()

  def _accumulate(local_vmem, peer_vmem):
    # HBM refs cannot be loaded directly: stage both operands into VMEM.
    # Both copies share local_send_sem; each wait() consumes its own copy's
    # completion, so concurrent starts are safe.
    load_local = pltpu.make_async_copy(input_ref, local_vmem, local_send_sem)
    load_peer = pltpu.make_async_copy(hbm_scratch, peer_vmem, local_send_sem)
    load_local.start()
    load_peer.start()
    load_local.wait()
    load_peer.wait()

    local_vmem[...] = local_vmem[...] + peer_vmem[...]

    store = pltpu.make_async_copy(local_vmem, output_ref, local_send_sem)
    store.start()
    store.wait()

  pl.run_scoped(
    _accumulate,
    pltpu.VMEM((shard_height, shard_width), input_ref.dtype),
    pltpu.VMEM((shard_height, shard_width), input_ref.dtype),
  )


def make_ar_bidi(x, interpret: None | bool = None):
    """Launch all_reduce_bidi on this device's partial result (call inside shard_map).

    The grid spec is built locally. Rebinding the module-level `grid_spec`
    would silently change the kernel that make_ag launches.

    Args:
      x: this device's partial result.
      interpret: None (the default) auto-detects: compile on TPU, use the Pallas
        interpreter elsewhere.

    Returns:
      The all-reduced array, with x's shape and dtype.
    """
    if interpret is None:
        interpret = jax.devices()[0].platform != 'tpu'

    shard = jax.ShapeDtypeStruct(x.shape, x.dtype)
    ar_bidi_grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        # A single exchange step: only correct for a y-ring of size 2.
        grid=(1,),
        in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
        # (reduced result, hbm_scratch), both left in HBM.
        out_specs=(
            pl.BlockSpec(memory_space=pl.ANY),
            pl.BlockSpec(memory_space=pl.ANY),
        ),
        # local_send_sem, send_sem_right, send_sem_left, recv_sem_right, recv_sem_left
        scratch_shapes=[pltpu.SemaphoreType.DMA] * 5,
    )
    reduced, _peer_partial = pl.pallas_call(
        all_reduce_bidi,
        grid_spec=ar_bidi_grid_spec,
        out_shape=(shard, shard),
        interpret=interpret,
    )(x)
    return reduced

In [None]:
# Added granularity of halving the outgoing ICI copies (L/R)
# NOTE: SHRINKING THE STEP SIZES (data sent) CAN BETTER OVERLAP??
# Except ICI is slowest comms channel ; this is good to test experimentally**

In [None]:
# Revisiting the matmul
# Might not need this except to explain tweaks from within kernel

In [None]:
"""
Instead of AG → slice → GEMM sequentially, overlap the weight
transfer with GEMM computation.

Each device (i, j) has:
  inputs[I_x, J_y]  (1024, 1024)
  weights[J_x, K]   (1024, 1024)

The correct partial product is: inputs[i,j] @ weights[j,:]
After AR over y: output[i,:] = sum_j inputs[i,j] @ weights[j,:]

For the AG over x (x_ring=2):
  x_idx == y_idx; local weights ARE weights[j,:], compute immediately
  x_idx != y_idx; need neighbor's weights, wait for RDMA

Start RDMA, run local GEMM (overlapping comm+compute on the devices that
already have the right chunk), then recompute with received weights only where needed.
"""

def _emit_gemm(x_ref, w_ref, o_ref, *, bm, bk, bn):
    """
    Emit a tiled GEMM pipeline
    All refs are HBM. emit_pipeline handles VMEM tiling + double-buffering.
    """
    m, k_dim = x_ref.shape
    _, n = w_ref.shape
    grid = (m // bm, n // bn, k_dim // bk)

    def body(x_vmem, w_vmem, o_vmem, accum):
        @pl.when(pl.program_id(2) == 0)
        def _():
            accum[...] = jnp.zeros_like(accum)

        accum[...] += jnp.dot(
            x_vmem[...], w_vmem[...],
            preferred_element_type=jnp.float32,
        )

        @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
        def _():
            o_vmem[...] = accum[...].astype(o_vmem.dtype)

    @functools.partial(pl.run_scoped, accum=pltpu.VMEM((bm, bn), jnp.float32))
    def _(accum):
        pltpu.emit_pipeline(
            functools.partial(body, accum=accum),
            grid=grid,
            in_specs=[
                pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
                pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
            ],
            out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        )(x_ref, w_ref, o_ref)


def fused_ag_gemm_kernel(
    input_ref,          # HBM: inputs shard (m_local, k_local)
    weight_ref,         # HBM: weights shard (k_local, n)
    output_ref,         # HBM: GEMM output (m_local, n)
    recv_weight_ref,    # HBM: workspace for received weights (k_local, n)
    send_sem, recv_sem, # RDMA semaphores
):
    """
    Fused AG + GEMM via a single collective permute along the x-ring.

    Device (x_idx, y_idx) holds inputs[x_idx, y_idx] and weights[x_idx, :],
    but its partial product needs weights[y_idx, :]:

      * x_idx == y_idx: the local weights are already correct, so the GEMM
        runs immediately and overlaps the outgoing RDMA.
      * x_idx != y_idx: the needed shard arrives from the left x-neighbour.
        Wait for it, then run the GEMM exactly once.

    Every device still sends its shard to its right x-neighbour, because that
    neighbour may be one that needs it.

    Each device runs exactly one GEMM. Computing with the local weights on
    x_idx != y_idx devices (and then recomputing) would only lengthen their
    critical path: max(GEMM, RDMA) + GEMM instead of RDMA + GEMM.

    Raises:
      NotImplementedError: if the x axis size is statically known and != 2.
        One hop only delivers weights[y_idx] on a 2-device ring.
    """
    x_idx = jax.lax.axis_index('x')
    y_idx = jax.lax.axis_index('y')
    x_ring = jax.lax.axis_size('x')
    # A single hop delivers weights[(x_idx - 1) % x_ring]. That equals the
    # needed weights[y_idx] for every x_idx != y_idx only when x_ring == 2.
    if isinstance(x_ring, int) and x_ring != 2:
        raise NotImplementedError(
            f"fused_ag_gemm_kernel supports an x ring of size 2, got {x_ring}"
        )
    right_neighbor = jax.lax.rem(x_idx + 1, x_ring)

    BM, BK, BN = 128, 128, 128

    # --- Step 1: Kick off async weight exchange along x-ring ---
    # Each device sends its shard right and receives from its left neighbor.
    rdma = pltpu.make_async_remote_copy(
        src_ref=weight_ref,
        dst_ref=recv_weight_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=(right_neighbor, y_idx),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    rdma.start()

    # --- Step 2: Local GEMM, only where the local weights are correct ---
    # This overlaps the in-flight RDMA.
    @pl.when(x_idx == y_idx)
    def _():
        _emit_gemm(input_ref, weight_ref, output_ref, bm=BM, bk=BK, bn=BN)

    # --- Step 3: Wait for the exchange ---
    # Every device both sends and receives, so all of them wait here.
    rdma.wait()

    # --- Step 4: GEMM with the received weights where needed ---
    # NOTE: if @pl.when around emit_pipeline gives trouble, the
    # fallback is to always copy the correct chunk into
    # recv_weight_ref (local or received) and run one GEMM.
    @pl.when(x_idx != y_idx)
    def _():
        _emit_gemm(input_ref, recv_weight_ref, output_ref, bm=BM, bk=BK, bn=BN)


def make_fused_ag_gemm(inputs, weights):
    """
    Run the fused AG + GEMM kernel on this device's shards (call under shard_map).

    Args:
      inputs:  local inputs shard, shape (m_local, k_local).
      weights: local weights shard, shape (k_local, n).

    Returns:
      The local partial product, shape (m_local, n), dtype ``inputs.dtype``.
      It still has to be all-reduced over 'y' to form the final output.

    Raises:
      ValueError: if the two shards disagree on the contracting dimension.
    """
    m_local, k_local = inputs.shape
    k_weights, n = weights.shape
    # The kernel multiplies the inputs shard by a (k_local, n) weights shard
    # (local or received), and the receive workspace is sized from k_local.
    # The two shards must therefore agree on k.
    if k_weights != k_local:
        raise ValueError(
            f"inputs shard has k={k_local} but weights shard has k={k_weights}"
        )

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        # One program: the kernel drives its own RDMA and inner pipelines.
        grid=(1,),
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),  # inputs
            pl.BlockSpec(memory_space=pl.ANY),  # weights
        ],
        # Two outputs: real output + HBM workspace for received weights
        out_specs=[
            pl.BlockSpec(memory_space=pl.ANY),  # GEMM output
            pl.BlockSpec(memory_space=pl.ANY),  # recv weight buffer
        ],
        # Flat list, so each semaphore is its own kernel argument
        # (send_sem, recv_sem). The previous nested list [[DMA, DMA]] appears
        # to be passed through as a single list argument, which would leave
        # recv_sem unbound. The flat form matches the Pallas docs.
        scratch_shapes=[pltpu.SemaphoreType.DMA] * 2,
    )

    out_shape = [
        jax.ShapeDtypeStruct((m_local, n), inputs.dtype),   # GEMM output
        jax.ShapeDtypeStruct((k_local, n), weights.dtype),  # recv workspace
    ]

    gemm_out, _recv_workspace = pl.pallas_call(
        fused_ag_gemm_kernel,
        grid_spec=grid_spec,
        out_shape=out_shape,
    )(inputs, weights)

    return gemm_out  # the workspace output is discarded

In [None]:
# PUT IT TOGETHER
# Interleave the compute with the ppermute.
# This is where the fun begins... How many versions can we cook up?
# We want block shapes that are multiples of (8, 128), and we benefit from larger blocks,
# largely because TPU v5e has FOUR MXUs -> TPU v6e has only 2, but they're 256x256.

def fused_ag_gemm_ar(inputs, weights):
    """Per-shard body: fused AG + GEMM partial product, then all-reduce it via `make_ar`."""
    return make_ar(make_fused_ag_gemm(inputs, weights))

# Each device runs fused_ag_gemm_ar on its local shards. The output is
# row-sharded over 'x' and declared replicated over 'y'; check_vma=False
# means that replication is not verified by shard_map.
_fused_shard_mapped = jax.shard_map(
    fused_ag_gemm_ar,
    mesh=mesh,
    in_specs=(P('x', 'y'), P('x', None)),
    out_specs=P('x', None),
    check_vma=False,
)
fused_fn = jax.jit(_fused_shard_mapped)

# Smoke run (also triggers compilation); the result is displayed.
fused_fn(inputs, weights)

In [None]:
# NUMERICS CHECK: fused Pallas pipeline vs. the plain jnp.matmul reference on the same sharded inputs
ref, test = jax_matmul(inputs, weights), fused_fn(inputs, weights)
numerics.compare(ref, test, rtol=1e-2, atol=1e-2, region_grid=(2, 2))

In [None]:
# PERFORMANCE CHECK: run the sfp.utils `benchmark` helper on the jitted fused AG + GEMM + AR pipeline
benchmark(fused_fn, inputs, weights)

In [None]:
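# PROFILING
# TODO: Don't forget to add profiler annotations (e.g. jax.named_scope) and compile with
# --xla-enable-transpose-trace (confirm the exact flag name; XLA flags are usually underscore-separated).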

In [None]:
# NOTE: Include a version with some XLA compiler flags enabled (e.g. async all-gather) to compare performance.