## Kernel Description

Suppose we are given a vector $A_{m}$ and a matrix $B_{m,n}$. The Einsum below computes
$$T_{n} = A_{m} \cdot B_{m,n}$$
$$K_{n,k} = T_{n} \cdot A_{k}$$
$$Z_{n,k} = B_{k,n} - K_{n,k},$$
where $k$ ranges over the same values as $m$ and the output is $Z$ up to transposition.

Note that in the Einsum below, `C` is the transpose of `A` and `D` is the transpose of `B`. The reduction along the `m` rank must finish before `K` and `Z` can be computed. This is therefore a two-pass algorithm: `A` and `B` are each loaded into the engine twice, once as `A` and `B` for the reduction and once as `C` and `D` for the update.
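As a point of reference, the three equations can be sketched in plain NumPy. This is a minimal illustration, not the kernel itself: the function name `qr_update_reference` and the small shapes are assumptions made for the example, and the `B` term in the last equation is taken to be the transpose of `B`, as the note on `D` indicates.

```python
import numpy as np

def qr_update_reference(a, b):
    """Reference computation of Z from a vector a (M,) and a matrix b (M, N).

    Hypothetical helper for illustration; returns Z with shape (N, M).
    """
    t = a @ b             # T_n = sum_m A_m * B_{m,n}, shape (N,)
    k = np.outer(t, a)    # K_{n,k} = T_n * A_k, shape (N, M)
    return b.T - k        # Z_{n,k} = B_{k,n} - K_{n,k}

rng = np.random.default_rng(0)
a = rng.random(4)         # M = 4
b = rng.random((4, 3))    # N = 3
z = qr_update_reference(a, b)

# Transposing Z gives the familiar rank-1 update B - a (a^T B).
assert np.allclose(z.T, b - np.outer(a, a @ b))
```

The reduction that produces `t` has to complete before any element of `k` can be formed, which is why the kernel needs two passes over the inputs.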

## Einsum

In [None]:
"""
einsum:
    declaration:
        AH: [M1, M0]
        AS: [M1, M0, N2]
        BH: [M1, M0, N2, N1, N0]
        BS: [M1, M0, N2, N1, N0]
        CH: [K1, K0]
        CS: [N2, K1, K0]
        CTS: [N2, N0, K1, K0]
        DH: [N2, N1, N0, K1, K0]
        DS: [N2, N1, N0, K1, K0]
        TP: [M1, N2, N1, N0]
        TS: [N2, N1, N0]
        KS: [N2, N1, N0, K1, K0]
        ZS: [N2, N1, N0, K1, K0]
        ZH: [N2, N1, N0, K1, K0]
    expressions:
        - AS[m1, m0, n2] = AH[m1, m0]
        - BS[m1, m0, n2, n1, n0] = BH[m1, m0, n2, n1, n0]
        - CS[n2, k1, k0] = CH[k1, k0]
        - DS[n2, n1, n0, k1, k0] = DH[n2, n1, n0, k1, k0]
        - CTS[n2, n0, k1, k0] = CS[n2, k1, k0]
        - TP[m1, n2, n1, n0] = BS[m1, m0, n2, n1, n0] . AS[m1, m0, n2] ;; M mul V add
        - TS[n2, n1, n0] = TP[m1, n2, n1, n0] ;; V binary(add)
        - KS[n2, n1, n0, k1, k0] = CTS[n2, n0, k1, k0] . TS[n2, n1, n0] ;; M binary(multiply)
        - ZS[n2, n1, n0, k1, k0] = DS[n2, n1, n0, k1, k0] . KS[n2, n1, n0, k1, k0] ;; M binary(subtract)
        - ZH[n2, n1, n0, k1, k0] = ZS[n2, n1, n0, k1, k0]
    shape:
        N0: 128
mapping:
    partitioning:
        AS:
            M0: [uniform_shape(128), uniform_shape(1)]
        BS:
            M0: [uniform_shape(128)]
        CS:
            K0: [uniform_shape(512), uniform_shape(512)]
        CTS:
            K0: [uniform_shape(512), uniform_shape(512)]
        TP:
            M0: [uniform_shape(128)]
    loop-order:
        AS: [N2, M1, M02, M01, M00]
        BS: [N2, M1, M01, N1, M00, N0]
        CS: [N2, K1, K02, K01, K00]
        DS: [N2, K1, N1, N0, K0]
        CTS: [N2, K1, K02, K01, N0, K00]
        TP: [N2, M1, N1, M01, M00, N0]
        TS: [N2, M1, N1, N0]
        KS: [N2, K1, N1, N0, K0]
        ZS: [N2, K1, N1, N0, K0]
        ZH: [N2, K1, N1, N0, K0]
    spacetime:
        AS:
            space: [M01]
            time: [N2, M1, M02, M00]
        BS:
            space: [M00]
            time: [N2, M1, M01, N1, N0]
        CS:
            space: [K01]
            time: [N2, K1, K02, K00]
        DS:
            space: [N0]
            time: [N2, K1, N1, K0]
        CTS:
            space: [N0]
            time: [N2, K1, K02, K01, K00]
        TP:
            space: [M00, N0]
            time: [N2, M1, N1, M01]
        TS:
            space: [N0]
            time: [N2, M1, N1]
        KS:
            space: [N0]
            time: [N2, K1, N1, K0]
        ZS:
            space: [N0]
            time: [N2, K1, N1, K0]
        ZH:
            space: [N0]
            time: [N2, K1, N1, K0]
format:
    AH:
        rank-order:
          AS: [M1, M02, M01, M00]
        bit-width: 16
    AS:
        rank-order:
          AS: [N2, M1, M01, M02, M00]
          TP: [N2, M1, M01, M00]
        bit-width: 16
    BH:
        rank-order:
          BS: [M1, M01, M00, N2, N1, N0]
        bit-width: 16
    BS:
        rank-order:
          BS: [N2, M1, M00, M01, N1, N0]
          TP: [N2, M1, M01, M00, N1, N0]
        bit-width: 16
    CH:
        rank-order:
          CS: [K1, K02, K01, K00]
        bit-width: 16
    CS:
        rank-order:
          CS: [N2, K1, K01, K02, K00]
          CTS: [N2, K1, K01, K02, K00]
        bit-width: 16
    DH:
        rank-order:
          DS: [N2, N1, N0, K1, K0]
        bit-width: 16
    DS:
        rank-order:
            DS: [N2, K1, N0, N1, K0]
            ZS: [N2, K1, N0, N1, K0]
        bit-width: 16
    CTS:
        rank-order:
          CTS: [N2, K1, N0, K02, K00, K01]
          KS: [N2, K1, N0, K0]
        bit-width: 16
    TP:
        rank-order:
          TP: [N2, M1, N1, N0]
          TS: [N2, M1, N1, N0]
        bit-width: 32
    TS:
        rank-order:
          TS: [N2, N1, N0]
          KS: [N2, N0, N1]
        bit-width: 16
    KS:
        rank-order:
          KS: [N2, K1, N1, N0, K0]
          ZS: [N2, K1, N1, N0, K0]
        bit-width: 16
    ZS:
        rank-order:
          ZS: [N2, K1, N1, N0, K0]
          ZH: [N2, K1, N1, N0, K0]
        bit-width: 16
    ZH:
        rank-order:
          ZH: [N2, N1, N0, K1, K0]
        bit-width: 16
binding:
  fusion:
    - einsums: [AS, BS, CS, DS, CTS, TP, TS, KS, ZS, ZH]
      ranks:
        - [AS.N2, BS.N2, CS.N2, DS.N2, CTS.N2, TP.N2, TS.N2, KS.N2, ZS.N2, ZH.N2]
        - [AS.M1, BS.M1, TP.M1, TS.M1]
        - [AS.M02]
        - [AS.M01]
        - [AS.M00]
        - [BS.M01]
        - [BS.N1]
        - [BS.M00]
        - [BS.N0]
        - [TP.N1, TS.N1]
        - [TP.M01]
        - [TP.M00]
        - [TP.N0]
        - [TS.N0]
        - [CS.K1, DS.K1, CTS.K1, KS.K1, ZS.K1, ZH.K1]
        - [CS.K02]
        - [CS.K01]
        - [CS.K00]
        - [CTS.K02]
        - [CTS.K01]
        - [CTS.N0]
        - [CTS.K00]
        - [DS.N1]
        - [DS.N0]
        - [DS.K0]
        - [KS.N1, ZS.N1, ZH.N1]
        - [KS.N0]
        - [KS.K0]
        - [ZS.N0]
        - [ZS.K0]
        - [ZH.N0]
        - [ZH.K0]
  engine:
    - einsum: AS
      component: DMAEngine
      instruction: Load
      ranks: [M01, M00]
    - einsum: BS
      component: DMAEngine
      instruction: Load
      ranks: [M00, N0]
    - einsum: CS
      component: DMAEngine
      instruction: Load
      ranks: [K01, K00]
    - einsum: DS
      component: DMAEngine
      instruction: Load
      ranks: [N0, K0]
    - einsum: CTS
      component: TensorEngine
      instruction: BroadcastTo
      ranks: [K01, N0, K00]
    - einsum: TP
      component: TensorEngine
      instruction: MatVec
      ranks: [M00, N0]
    - einsum: TS
      component: VectorEngine
      instruction: LoopReduce
      ranks: [N0, N1]
    - einsum: KS
      component: VectorEngine
      instruction: SimpleTensorScalar
      ranks: [N0, K0]
    - einsum: ZS
      component: VectorEngine
      instruction: SimpleTensorTensor
      ranks: [N0, K0]
    - einsum: ZH
      component: DMAEngine
      instruction: Store
      ranks: [N0, K0]
  tensor:
    - name: AH
      einsum: AS
      component: HBM
      ranks:
        unbound: []
        noncontiguous: [M1, M02, M01]
        contiguous: [M00]
    - name: AS
      einsum: AS
      component: SBUF
      constructor: empty
      ranks:
        unbound: [N2, M1]
        partition: [M01]
        free: [M02, M00]
    - name: AS
      einsum: TP
      component: SBUF
      ranks:
        unbound: [N2, M1]
        partition: [M00]
        free: [M01]
    - name: BH
      einsum: BS
      component: HBM
      ranks:
        unbound: []
        noncontiguous: [M1, M01, M00, N2, N1]
        contiguous: [N0]
    - name: BS
      einsum: BS
      component: SBUF
      constructor: empty
      ranks:
        unbound: [N2, M1]
        partition: [M00]
        free: [M01, N1, N0]
    - name: BS
      einsum: TP
      component: SBUF
      ranks:
        unbound: [N2, M1]
        partition: [M00]
        free: [M01, N1, N0]
    - name: CH
      einsum: CS
      component: HBM
      ranks:
        unbound: []
        noncontiguous: [K1, K02, K01]
        contiguous: [K00]
    - name: CS
      einsum: CS
      component: SBUF
      constructor: empty
      ranks:
        unbound: [N2, K1]
        partition: [K01]
        free: [K02, K00]
    - name: CS
      einsum: CTS
      component: SBUF
      ranks:
        unbound: [N2, K1]
        partition: [K01]
        free: [K02, K00]
    - name: DH
      einsum: DS
      component: HBM
      ranks:
        unbound: []
        noncontiguous: [N2, N1, N0, K1]
        contiguous: [K0]
    - name: DS
      einsum: DS
      component: SBUF
      constructor: empty
      ranks:
        unbound: [N2, K1]
        partition: [N0]
        free: [N1, K0]
    - name: DS
      einsum: ZS
      component: SBUF
      ranks:
        unbound: [N2, K1]
        partition: [N0]
        free: [N1, K0]
    - name: CTS
      einsum: CTS
      component: SBUF
      constructor: empty
      ranks:
        unbound: [N2, K1]
        partition: [N0]
        free: [K02, K00, K01]
    - name: CTS
      einsum: KS
      component: SBUF
      ranks:
        unbound: [N2, K1]
        partition: [N0]
        free: [K0]
    - name: TP
      einsum: TP
      component: PSUM
      constructor: zeros
      ranks:
        unbound: [N2, M1]
        partition: [N0]
        free: [N1]
    - name: TP
      einsum: TS
      component: PSUM
      ranks:
        unbound: [N2, M1]
        partition: [N0]
        free: [N1]
    - name: TS
      einsum: TS
      component: SBUF
      constructor: zeros
      ranks:
        unbound: [N2]
        partition: [N0]
        free: [N1]
    - name: TS
      einsum: KS
      component: SBUF
      ranks:
        unbound: [N2]
        partition: [N0]
        free: [N1]
    - name: KS
      einsum: KS
      component: SBUF
      constructor: empty
      ranks:
        unbound: [N2, K1, N1]
        partition: [N0]
        free: [K0]
    - name: KS
      einsum: ZS
      component: SBUF
      ranks:
        unbound: [N2, K1, N1]
        partition: [N0]
        free: [K0]
    - name: ZS
      einsum: ZS
      component: SBUF
      constructor: empty
      ranks:
        unbound: [N2, K1, N1]
        partition: [N0]
        free: [K0]
    - name: ZS
      einsum: ZH
      component: SBUF
      ranks:
        unbound: [N2, K1, N1]
        partition: [N0]
        free: [K0]
    - name: ZH
      einsum: ZH
      component: HBM
      constructor: empty
      ranks:
        unbound: []
        noncontiguous: [N2, N1, N0]
        contiguous: [K1, K0]
options:
  kernel-name: qr_update
  cascade-outputs: [ZH]
"""
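The `partitioning` entries above split `M0` and `K0` into the sub-ranks (`M02`, `M01`, `M00`, and so on) that the rest of the mapping refers to. The sketch below shows one reading of those splits that is consistent with the buffer shapes in the generated kernel. The helper `split_uniform` is hypothetical, its interpretation of `uniform_shape` is an assumption rather than the tool's documented behavior, and the sizes `M0 = K0 = 2048` are taken from the Test section.

```python
def split_uniform(size, shapes):
    """Return sub-rank sizes, outermost first, after successive uniform splits.

    Hypothetical helper: each entry in `shapes` is assumed to cut the current
    rank into chunks of that shape, and the next entry acts on one chunk.
    """
    ranks = []
    remaining = size
    for shape in shapes:
        ranks.append(remaining // shape)  # number of chunks at this level
        remaining = shape                 # the next split sees a single chunk
    ranks.append(remaining)               # innermost rank
    return ranks

M0 = K0 = 2048  # sizes used in the Test section

as_ranks = split_uniform(M0, [128, 1])     # AS: [M02, M01, M00]
bs_ranks = split_uniform(M0, [128])        # BS: [M01, M00]
cs_ranks = split_uniform(K0, [512, 512])   # CS: [K02, K01, K00]

print(as_ranks, bs_ranks, cs_ranks)
```

Under this reading, `AS` has a 128-wide `M01` on the partition dimension, `BS` has a 128-wide `M00` there, and `CS` has a single-partition `K01` with 512-wide `K00` chunks, matching the `nl.par_dim` arguments in the generated function.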

## Generated NKI function

In [None]:
import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.isa as ni
import neuronxcc.nki.language as nl

@nki.jit
def qr_update(AH, BH, CH, DH, M1, M0, N2, N1, K1, K0):
    ZH = nl.ndarray((*[N2, N1, 128], *[K1, K0]), dtype=nl.float16, buffer=nl.shared_hbm)
    for as_n2 in nl.affine_range(N2):
        # First pass: reduce along m by accumulating the per-M1 partial sums into TS.
        TS = nl.zeros((nl.par_dim(*[128]), *[N1]), dtype=nl.float16, buffer=nl.sbuf)
        for as_m1 in nl.affine_range(M1):
            AS = nl.ndarray((nl.par_dim(*[128]), *[M0 // 128, 1]), dtype=nl.float16, buffer=nl.sbuf)
            for as_m02 in nl.affine_range(M0 // 128):
                AS[:, as_m02, :] = nl.load(AH[as_m1, as_m02, :, :])
            # Edit: Flatten N1 and N0
            BS = nl.ndarray((nl.par_dim(*[128]), *[M0 // 128, N1 * 128]), dtype=nl.float16, buffer=nl.sbuf)
            for bs_m01 in nl.affine_range(M0 // 128):
                # Edit: Flatten N1 and N0
                BS[:, bs_m01, :] = nl.load(BH[as_m1, bs_m01, :, as_n2, :])
            TP = nl.zeros((nl.par_dim(*[128]), *[N1]), dtype=nl.float32, buffer=nl.psum)
            for tp_n1 in nl.affine_range(N1):
                for tp_m01 in nl.affine_range(M0 // 128):
                    AS_for_TP = AS.reshape((128, M0 // 128))
                    # Edit: Unflatten for TP
                    BS_for_TP = BS.reshape((128, M0 // 128, N1, 128))
                    TP[:, tp_n1] += ni.nc_matmul(BS_for_TP[:, tp_m01, tp_n1, :], AS_for_TP[:, tp_m01])
                TS[:, tp_n1] = nl.loop_reduce(TP[:, tp_n1], np.add, loop_indices=[as_m1], dtype=nl.float16)
        # Second pass: TS is fully reduced, so stream C and D tiles to form K and Z.
        for cs_k1 in nl.affine_range(K1):
            CS = nl.ndarray((nl.par_dim(*[1]), *[K0 // 512, 512]), dtype=nl.float16, buffer=nl.sbuf)
            for cs_k02 in nl.affine_range(K0 // 512):
                CS[:, cs_k02, :] = nl.load(CH[cs_k1, cs_k02, :, :])
            DS = nl.ndarray((nl.par_dim(*[128]), *[N1, K0]), dtype=nl.float16, buffer=nl.sbuf)
            for ds_n1 in nl.affine_range(N1):
                DS[:, ds_n1, :] = nl.load(DH[as_n2, ds_n1, :, cs_k1, :])
            CTS = nl.ndarray((nl.par_dim(*[128]), *[K0 // 512, 512, 1]), dtype=nl.float16, buffer=nl.sbuf)
            for cts_k02 in nl.affine_range(K0 // 512):
                CTS[:, cts_k02, :, :] = nl.broadcast_to(CS[:, cts_k02, :], shape=tuple([1, 128, 512][1:]))
            for ks_n1 in nl.affine_range(N1):
                KS = nl.ndarray((nl.par_dim(*[128]), *[K0]), dtype=nl.float16, buffer=nl.sbuf)
                CTS_for_KS = CTS.reshape((128, K0))
                KS[:, :] = ni.tensor_scalar(CTS_for_KS[:, :], np.multiply, TS[:, ks_n1])
                ZS = nl.ndarray((nl.par_dim(*[128]), *[K0]), dtype=nl.float16, buffer=nl.sbuf)
                ZS[:, :] = ni.tensor_tensor(DS[:, ks_n1, :], KS[:, :], np.subtract)
                nl.store(ZH[as_n2, ks_n1, :, cs_k1, :], value=ZS[:, :])
    return ZH
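The loop structure of the generated function can be emulated in NumPy at small sizes to check the two-pass logic independently of the hardware. This is a sketch under stated assumptions: `two_pass_tiled` is a hypothetical helper, it works in float64, and it ignores the finer splits of `M0` and `K0` across partitions, keeping only the `N2`, `M1`, `K1`, and `N1` loops.

```python
import numpy as np

def two_pass_tiled(A, B, C, D):
    """NumPy emulation of the kernel's outer loop nest (hypothetical helper).

    A: (M1, M0), B: (M1, M0, N2, N1, N0), C: (K1, K0), D: (N2, N1, N0, K1, K0).
    Returns Z with the same tiled shape as D.
    """
    M1, M0, N2, N1, N0 = B.shape
    K1, K0 = C.shape
    Z = np.empty((N2, N1, N0, K1, K0))
    for n2 in range(N2):
        # Pass 1: reduce along m, one M1 tile at a time, into TS[n0, n1].
        TS = np.zeros((N0, N1))
        for m1 in range(M1):
            # TP[n0, n1] = sum over m0 of B[m1, m0, n2, n1, n0] * A[m1, m0]
            TP = np.einsum("mji,m->ij", B[m1, :, n2], A[m1])
            TS += TP
        # Pass 2: TS is complete, so each (k1, n1) tile of Z can be formed.
        for k1 in range(K1):
            for n1 in range(N1):
                KS = TS[:, n1, None] * C[k1][None, :]        # (N0, K0)
                Z[n2, n1, :, k1, :] = D[n2, n1, :, k1, :] - KS
    return Z

# Small sizes chosen for illustration; M1 * M0 must equal K1 * K0.
M1, M0, N2, N1, N0, K1, K0 = 2, 4, 2, 3, 2, 4, 2
M, N = M1 * M0, N2 * N1 * N0

rng = np.random.default_rng(0)
a = rng.random(M)
b = rng.random((M, N))

Z = two_pass_tiled(
    a.reshape(M1, M0),
    b.reshape(M1, M0, N2, N1, N0),
    a.reshape(K1, K0),
    b.T.reshape(N2, N1, N0, K1, K0),
)
expected = b.T - np.outer(a @ b, a)   # untiled reference, shape (N, M)
assert np.allclose(Z.reshape(N, M), expected)
```

Note that `TS` is allocated once per `n2` and only read inside the `k1` loop, which mirrors how the generated function keeps `TS` in SBUF across both passes.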

## Test

In [None]:
from neuronxcc import nki
import numpy as np
from qr_update import qr_update

M = 12288
N = 22528

a = np.random.rand(M,1).astype(np.float16)
b = np.random.rand(M,N).astype(np.float16)
d = np.random.rand(N,M).astype(np.float16)

A = a.reshape(6, 16, 128, 1)
B = b.reshape(6, 16, 128, 11, 16 * 128)
C = a.reshape(6, 4, 1, 512)
D = d.reshape(11, 16, 128, 6, 2048)

expected = d.T - a @ (a.T @ b)
result_nki = qr_update(A, B, C, D, M1=6, M0=2048, N2=11, N1=16, K1=6, K0=2048)
result_nki = result_nki.reshape(22528, 12288).T
print(expected)
print(result_nki)
is_close = np.allclose(result_nki, expected, rtol=1e-2, atol=1e-3)
print("Result match: ", is_close)
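The reshapes in the test hard-code the tile counts, so it can help to confirm that each host layout covers exactly the elements of its source array. The sketch below derives the same shapes from the kernel arguments. The dictionary `host_shapes` is a name introduced here for illustration, and the constants 128 and 512 are the partition and chunk sizes that appear in the generated function.

```python
import math

M, N = 12288, 22528
M1, M0 = 6, 2048
N2, N1, N0 = 11, 16, 128
K1, K0 = 6, 2048

# Shape passed to the kernel and element count of the source array.
host_shapes = {
    "A": ((M1, M0 // 128, 128, 1), M),
    "B": ((M1, M0 // 128, 128, N2, N1 * N0), M * N),
    "C": ((K1, K0 // 512, 1, 512), M),
    "D": ((N2, N1, N0, K1, K0), N * M),
}

for name, (shape, size) in host_shapes.items():
    # Every reshape must preserve the total number of elements.
    assert math.prod(shape) == size, name

# The tiled rank sizes must multiply back to the flat dimensions.
assert M1 * M0 == M and K1 * K0 == M and N2 * N1 * N0 == N
```

The output comes back with shape `(N2, N1, 128, K1, K0)`, which is why the test flattens it to `(N, M)` and transposes before comparing against `expected`.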

## Benchmark

In [None]:
import neuronxcc.nki as nki
import numpy as np
from qr_update import qr_update

M = 12288
N = 22528

a = np.random.rand(M,1).astype(np.float16)
b = np.random.rand(M,N).astype(np.float16)
d = np.random.rand(N,M).astype(np.float16)

A = a.reshape(6, 16, 128, 1)
B = b.reshape(6, 16, 128, 11, 16 * 128)
C = a.reshape(6, 4, 1, 512)
D = d.reshape(11, 16, 128, 6, 2048)

def benchmark_nki(nki_func):
    bench_func = nki.benchmark(warmup=5, iters=10)(nki_func)
    bench_func(A, B, C, D, M1=6, M0=2048, N2=11, N1=16, K1=6, K0=2048)
    latency_res = bench_func.benchmark_result.nc_latency
    p99 = latency_res.get_latency_percentile(99)
    print("Latency: {:.2f} ms (P99)".format(p99 / 1000.0))

print("Benchmarking qr_update")
benchmark_nki(qr_update)

## Profile

In [None]:
from neuronxcc import nki
from pathlib import Path
import numpy as np
from qr_update import qr_update

WORKING_DIRECTORY = Path.cwd()

M = 12288
N = 22528

a = np.random.rand(M,1).astype(np.float16)
b = np.random.rand(M,N).astype(np.float16)
d = np.random.rand(N,M).astype(np.float16)

A = a.reshape(6, 16, 128, 1)
B = b.reshape(6, 16, 128, 11, 16 * 128)
C = a.reshape(6, 4, 1, 512)
D = d.reshape(11, 16, 128, 6, 2048)

profile_func = nki.profile(working_directory=WORKING_DIRECTORY, save_neff_name='file.neff', 
                           save_trace_name='profile.ntff', profile_nth=2)(qr_update)
profile_func(A, B, C, D, M1=6, M0=2048, N2=11, N1=16, K1=6, K0=2048)