## Einsum

In [None]:
"""
einsum:
    declaration:
        A: [K, M]
        B: [K, N]
        Z: [M, N]
    expressions:
        - Z[m, n] = A[k, m] * B[k, n]
mapping:
    rank-order:
        A: [K, M]
        B: [K, N]
        Z: [M, N]
    partitioning:
        Z:
            K: [uniform_shape(K1), uniform_shape(128)]
            M: [uniform_shape(M1), uniform_shape(128)]
            N: [uniform_shape(N1), uniform_shape(512)]
    loop-order:
        Z: [M2, N2, K2, M1, N1, K1, M0, N0, K0]
    spacetime:
        Z:
            space: [M0, N0, K0]
            time: [M2, N2, K2, M1, N1, K1]
"""

## GEMM

In [None]:
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

@nki.jit
def nki_gemm(AH, BH, M1=1024, N1=1024, K1=1024):
    """NKI kernel computing the matrix-matrix product ZH = AH.T @ BH.

    Follows the TeAAL mapping in the Einsum section: each rank is split into
    outer blocks (M1 x N1 x K1 elements, staged in SBUF) and hardware tiles
    (M0 x N0 x K0 = 128 x 512 x 128) that are fed to nc_matmul.  Extents that
    are not multiples of the block sizes are handled with masked loads and
    stores, so AH and BH may have arbitrary shapes.

    Args:
        AH: an input tensor of shape [K2, M2] (the left operand, transposed).
        BH: an input tensor of shape [K2, N2].
        M1: output rows per outer block; must be a multiple of M0 (128).
        N1: output columns per outer block; must be a multiple of N0 (512).
        K1: contraction length per outer block; must be a multiple of K0 (128).

    Returns:
        ZH: the resulting output tensor of shape [M2, N2], with AH's dtype.
    """

    K2, M2 = AH.shape
    K2_, N2 = BH.shape
    assert K2 == K2_, "AH and BH must have the same contraction dimension"
    ZH = nl.ndarray((M2, N2), dtype=AH.dtype, buffer=nl.shared_hbm)

    M0 = nl.tile_size.gemm_stationary_fmax  # 128
    K0 = nl.tile_size.pmax  # 128
    N0 = nl.tile_size.gemm_moving_fmax  # 512

    # The tile loops below step through each block in whole tiles (X1 // X0),
    # while the block loops advance by X1.  A block size that is not a tile
    # multiple would silently skip output rows/columns or contraction terms.
    assert M1 % M0 == 0, "M1 must be a multiple of the M tile size"
    assert N1 % N0 == 0, "N1 must be a multiple of the N tile size"
    assert K1 % K0 == 0, "K1 must be a multiple of the K tile size"

    # Blocking large tensors
    for m2 in nl.affine_range((M2 + M1 - 1) // M1):
        for n2 in nl.affine_range((N2 + N1 - 1) // N1):
            # Accumulator for one M1 x N1 output block.  Every K block adds
            # its partial product into it (`+=` below), so it must start at
            # zero; nl.ndarray would leave it uninitialized.
            ZS = nl.zeros((M1 // M0, nl.par_dim(M0), N1),
                          dtype=AH.dtype,
                          buffer=nl.sbuf)
            # Each iteration read-modify-writes ZS, a loop-carried dependency,
            # so this loop is kept sequential.  The NKI matmul tutorial kernel
            # does the same for its K-block loop.
            for k2 in nl.sequential_range((K2 + K1 - 1) // K1):
                # Zero-filled so that elements masked out of the loads below
                # (past the edge of AH/BH) contribute nothing to the products.
                AS = nl.zeros((K1 // K0, M1 // M0, nl.par_dim(K0), M0),
                              dtype=AH.dtype,
                              buffer=nl.sbuf)
                BS = nl.zeros((K1 // K0, N1 // N0, nl.par_dim(K0), N0),
                              dtype=BH.dtype,
                              buffer=nl.sbuf)

                # Loading necessary tiles
                i_AS = nl.mgrid[0:K0, 0:M0]
                for m1 in nl.affine_range(M1 // M0):
                    for k1 in nl.affine_range(K1 // K0):
                        e1 = (K1 * k2 + K0 * k1) + i_AS.p
                        e2 = (M1 * m2 + M0 * m1) + i_AS.x
                        AS[k1, m1, i_AS.p, i_AS.x] = nl.load(
                            AH[e1, e2],
                            mask=((e1 < K2) & (e2 < M2)))

                i_BS = nl.mgrid[0:K0, 0:N0]
                for n1 in nl.affine_range(N1 // N0):
                    for k1 in nl.affine_range(K1 // K0):
                        e1 = (K1 * k2 + K0 * k1) + i_BS.p
                        e2 = (N1 * n2 + N0 * n1) + i_BS.x
                        BS[k1, n1, i_BS.p, i_BS.x] = nl.load(
                            BH[e1, e2],
                            mask=((e1 < K2) & (e2 < N2)))

                # Perform matmul
                i_AS_mm = nl.mgrid[0:K0, 0:M0]
                i_BS_mm = nl.mgrid[0:K0, 0:N0]
                i_ZP_mm = nl.mgrid[0:M0, 0:N0]
                for m1 in nl.affine_range(M1 // M0):
                    for n1 in nl.affine_range(N1 // N0):
                        # float32 PSUM accumulator for the K1 // K0 tile
                        # products of this (m1, n1) output tile.
                        ZP = nl.zeros((M0, N0), dtype=nl.float32, buffer=nl.psum)
                        for k1 in nl.affine_range(K1 // K0):
                            ZP[...] += nisa.nc_matmul(
                                AS[k1, m1, i_AS_mm.p, i_AS_mm.x],
                                BS[k1, n1, i_BS_mm.p, i_BS_mm.x])

                        # Accumulate on corresponding SBUF tile
                        ZS[m1, i_ZP_mm.p, n1 * N0 + i_ZP_mm.x] += ZP[i_ZP_mm.p, i_ZP_mm.x]

            # Copying the result from SBUF to HBM; the mask drops the padded
            # rows/columns that fall outside ZH.
            i_ZS = nl.mgrid[0:M0, 0:N1]
            for m1 in nl.affine_range(M1 // M0):
                e1 = (M1 * m2 + M0 * m1) + i_ZS.p
                e2 = (N1 * n2) + i_ZS.x
                nl.store(ZH[e1, e2],
                         value=ZS[m1, i_ZS.p, i_ZS.x], mask=((e1 < M2) & (e2 < N2)))

    return ZH

## AWS GEMM

The program below is copied directly from the [NKI Matrix Multiplication tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/matrix_multiplication.html).

In [None]:
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

@nki.jit
def nki_matmul_fully_optimized_(
    lhsT,
    rhs,
    # Meta-parameters
    TILES_IN_BLOCK_M=16,
    TILES_IN_BLOCK_N=2,
    TILES_IN_BLOCK_K=8,
):
  """NKI kernel to compute a large matrix multiplication efficiently by
     blocking all dimensions and doing layout optimization.

  Computes result = lhsT.T @ rhs.  The code is kept as published in the AWS
  NKI matrix multiplication tutorial (it serves as the benchmark baseline);
  only comments have been added.

  Args:
      lhsT: an input tensor of shape [K,M], where K is a multiple of 128 *
        TILES_IN_BLOCK_K and M is a multiple of 128 * TILES_IN_BLOCK_M.  It is the
        left-hand-side argument of the matrix multiplication, delivered transposed
        for optimal performance.
      rhs: an input tensor of shape [K,N],  where K is a multiple of 128 *
        TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N.  It is
        the right-hand-side argument of the matrix multiplication.
      TILES_IN_BLOCK_*: meta parameters to control blocking dimensions
        (the number of hardware tiles per block along M, N and K).
  Returns:
      result: the resulting output tensor of shape [M,N], with lhsT's dtype.
  """

  K, M = lhsT.shape
  K_, N = rhs.shape
  assert K == K_, "lhsT and rhs must have the same contraction dimension"
  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
  TILE_K = nl.tile_size.pmax  # 128
  TILE_N = nl.tile_size.gemm_moving_fmax  # 512

  BLOCK_M = TILE_M * TILES_IN_BLOCK_M
  BLOCK_N = TILE_N * TILES_IN_BLOCK_N
  BLOCK_K = TILE_K * TILES_IN_BLOCK_K

  # the size has to be multiple of block size
  # (there is no edge masking, unlike nki_gemm)
  assert M % BLOCK_M == 0
  assert N % BLOCK_N == 0
  assert K % BLOCK_K == 0

  NUM_BLOCK_M = M // BLOCK_M
  NUM_BLOCK_N = N // BLOCK_N
  NUM_BLOCK_K = K // BLOCK_K

  # Blocking N dimension (the RHS free dimension)
  for n in nl.affine_range(NUM_BLOCK_N):
    # SBUF accumulator for one BLOCK_N-wide column strip of the output (all M
    # blocks).  It is zero-initialized because every K block adds into it.
    result_tiles = nl.zeros((NUM_BLOCK_M, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N,
                             nl.par_dim(TILE_M), TILE_N),
                            dtype=lhsT.dtype,
                            buffer=nl.sbuf)

    # Blocking K dimension (the contraction dimension)
    # Use `sequential_range` because we do not want the compiler to change this loop by, 
    # for example, vectorizing it
    for k in nl.sequential_range(NUM_BLOCK_K):
      # Loading tiles from rhs
      # setting the load tile to `TILE_K x BLOCK_SIZE_N` to optimize DMA performance
      i_rhs = nl.mgrid[0:TILE_K, 0:BLOCK_N]
      rhs_tiles = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), BLOCK_N),
                             dtype=rhs.dtype,
                             buffer=nl.sbuf)

      for bk_r in nl.affine_range(TILES_IN_BLOCK_K):
        rhs_tiles[bk_r, i_rhs.p, i_rhs.x] = nl.load(
            rhs[(TILES_IN_BLOCK_K * k + bk_r) * TILE_K + i_rhs.p,
                BLOCK_N * n + i_rhs.x])

      # Blocking M dimension (the LHS free dimension)
      for m in nl.affine_range(NUM_BLOCK_M):
        # Loading tiles from lhsT
        i_lhsT = nl.mgrid[0:TILE_K, 0:BLOCK_M]
        lhsT_tiles = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), BLOCK_M),
                                dtype=lhsT.dtype,
                                buffer=nl.sbuf)
        for bk_l in nl.affine_range(TILES_IN_BLOCK_K):
          lhsT_tiles[bk_l, i_lhsT.p, i_lhsT.x] = nl.load(
              lhsT[(TILES_IN_BLOCK_K * k + bk_l) * TILE_K + i_lhsT.p,
                   BLOCK_M * m + i_lhsT.x])

        # Do matmul with all tiles in the blocks
        i_lhsT_mm = nl.mgrid[0:TILE_K, 0:TILE_M]
        i_rhs_mm = nl.mgrid[0:TILE_K, 0:TILE_N]
        i_res_mm = nl.mgrid[0:TILE_M, 0:TILE_N]
        for bn in nl.affine_range(TILES_IN_BLOCK_N):
          for bm in nl.affine_range(TILES_IN_BLOCK_M):
            # float32 PSUM accumulator for this block's K-tile products.
            res_tile = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)

            for bk in nl.affine_range(TILES_IN_BLOCK_K):
              res_tile[...] += nisa.nc_matmul(
                  lhsT_tiles[bk, i_lhsT_mm.p, bm * TILE_M + i_lhsT_mm.x],
                  rhs_tiles[bk, i_rhs_mm.p, bn * TILE_N + i_rhs_mm.x])

            # Accumulate on corresponding SBUF tile
            # (cross-K-block sums are therefore kept in lhsT's dtype)
            result_tiles[m, bm, bn, i_res_mm.p,
                         i_res_mm.x] += res_tile[i_res_mm.p, i_res_mm.x]

    # Copying the result from SBUF to HBM
    # NOTE(review): TILE_K is used below as the partition size and row stride
    # of the result, whose rows run along M; TILE_M looks intended.  This is
    # harmless while both are 128 (see the tile-size comments above).
    for m in nl.affine_range(NUM_BLOCK_M):
      for bm in nl.affine_range(TILES_IN_BLOCK_M):
        i_res = nl.mgrid[0:TILE_K, 0:TILE_N]
        i_res_packed = nl.mgrid[0:TILE_K, 0:BLOCK_N]
        result_packed = nl.ndarray((TILE_K, BLOCK_N),
                                   dtype=result_tiles.dtype,
                                   buffer=nl.sbuf)

        # coalesce result tiles for better DMA performance
        for bn in nl.affine_range(TILES_IN_BLOCK_N):
          result_packed[i_res.p,
                        bn * TILE_N + i_res.x] = nl.copy(result_tiles[m, bm, bn,
                                                                      i_res.p,
                                                                      i_res.x])
        nl.store(result[(TILES_IN_BLOCK_M * m + bm) * TILE_K + i_res_packed.p,
                        BLOCK_N * n + i_res_packed.x],
                 value=result_packed[i_res_packed.p, i_res_packed.x])

  return result

## Test

In [None]:
import numpy as np
from gemm import nki_gemm

# Set matrix dimensions.
# They are deliberately NOT multiples of nki_gemm's default block size (1024)
# or of its tile sizes (128 / 512), so the masked loads and stores at the
# matrix edges are exercised.  A tile-aligned alternative is
# K=1024, M=4096, N=2048.
K = 1500
M = 4500
N = 2500

# Fixed seed so that any mismatch can be reproduced exactly.
rng = np.random.default_rng(0)

# Test random matrices
for trial in range(5):
    A = rng.random((K, M)).astype(np.float16)
    B = rng.random((K, N)).astype(np.float16)
    result_nki = nki_gemm(A, B)
    result_np = np.dot(A.T, B)
    is_close = np.allclose(result_nki, result_np, rtol=1e-2, atol=1e-4)
    # Largest absolute deviation, computed in float32 to avoid fp16 rounding
    # in the diagnostic itself.
    max_err = np.max(np.abs(np.asarray(result_nki, dtype=np.float32)
                            - result_np.astype(np.float32)))
    print(f"Trial {trial}: result match: {is_close} (max abs error {max_err:.4g})")

## Benchmark

In [None]:
import neuronxcc.nki as nki
import numpy as np
from gemm import nki_gemm
from aws_gemm import nki_matmul_fully_optimized_

K, M, N = 16384, 8192, 16384

A = np.random.rand(K, M).astype(np.float16)
B = np.random.rand(K, N).astype(np.float16)


def benchmark_nki(nki_func):
    """Time `nki_func` on the module-level inputs A and B and print P99 latency.

    Runs 5 warm-up iterations followed by 10 timed iterations via
    nki.benchmark.  The reported percentile is divided by 1000 before being
    printed in ms.
    """
    timed_kernel = nki.benchmark(warmup=5, iters=10)(nki_func)
    timed_kernel(A, B)
    nc_latency = timed_kernel.benchmark_result.nc_latency
    p99_raw = nc_latency.get_latency_percentile(99)
    print("Latency: {:.2f} ms (P99)".format(p99_raw / 1000.0))


# Benchmark the tutorial baseline first, then the TeAAL-derived kernel.
for banner, kernel in (
    ("Benchmarking aws_gemm", nki_matmul_fully_optimized_),
    ("Benchmarking gemm", nki_gemm),
):
    print(banner)
    benchmark_nki(kernel)

## Result

The latencies of the two GEMM kernels above — `nki_gemm`, which follows the TeAAL specification, and `nki_matmul_fully_optimized_`, which is taken directly from AWS's NKI tutorial — are comparable, as measured by the P99 latency printed by the benchmark. `nki_gemm` is generally slightly faster than `nki_matmul_fully_optimized_`.