### Example Small Matrix Kernel Generation Using Exo ###

In [1]:
# Imports
import driver

In [2]:
# Create the single microkernel generator shared by every cell below.
# NOTE(review): generated microkernels appear to accumulate on this object
# (outputs are numbered microkernel_..._0, _1, ...), so later cells depend on
# the order in which earlier cells ran — use Restart & Run All.
kernel_generator = driver.MicrokernelGenerator()

First, here is the base SGEMM procedure: a plain `i`/`j`/`k` loop nest that accumulates `C[i, j] += A[i, k] * B[k, j]`.

In [3]:
# Display the base SGEMM procedure held by the generator (a plain i/j/k loop nest).
kernel_generator.sgemm_base

```python
def SGEMM(M: size, N: size, K: size, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM,
          C: f32[M, N] @ DRAM):
    assert M >= 1
    assert N >= 1
    assert K >= 1
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1
    for i in par(0, M):
        for j in par(0, N):
            for k in par(0, K):
                C[i, j] += A[i, k] * B[k, j]

```

Now, we will generate a 32x32 matmul kernel with a 4x16 microkernel that uses ARM Neon intrinsics.

In [4]:
# ARM Neon, 32x32x32 problem tiled by a 4x16 microkernel. Both 4 and 16
# divide 32, so no edge-case microkernels are generated (only the scalar
# remainder loops remain, and they run zero iterations at this size).
sgemm_exo_neon = kernel_generator.generate_exo_sgemm(
    driver.NeonMachine,
    4, 16,       # microkernel tile: 4 rows x 16 columns of C
    32, 32, 32,  # problem sizes (order presumably M, N, K — TODO confirm)
)
sgemm_exo_neon

```python
def SGEMM(M: size, N: size, K: size, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM,
          C: f32[M, N] @ DRAM):
    assert M >= 1
    assert N >= 1
    assert K >= 1
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1
    for ji in par(0, N / 16):
        B_strip: f32[K, 16] @ DRAM
        for i0 in seq(0, K):
            for i1 in seq(0, 16):
                B_strip[i0, i1] = B[i0, i1 + 16 * ji]
        for ii in par(0, M / 4):
            microkernel_4x16_0(K, A[4 * ii:4 * ii + 4, 0:K],
                               B_strip[0:K, 0:16], C[4 * ii:4 * ii + 4,
                                                     16 * ji:16 * ji + 16])
        for io in par(0, M % 4):
            for jo in par(0, 16):
                for k in par(0, K):
                    C[io + M / 4 * 4, 16 * ji +
                      jo] += A[io + M / 4 * 4, k] * B[k, 16 * ji + jo]
    for ii in par(0, M / 4):
        for io in par(0, 4):
            for jo in par(0, N % 16):
                for k in par(0, K):
                    C[4 * ii + io,
                      jo + N / 16 * 16] += A[4 * ii + io,
                                             k] * B[k, jo + N / 16 * 16]
    for io in par(0, M % 4):
        for jo in par(0, N % 16):
            for k in par(0, K):
                C[io + M / 4 * 4, jo +
                  N / 16 * 16] += A[io + M / 4 * 4, k] * B[k, jo + N / 16 * 16]

```

Here's what the microkernel itself looks like:

In [5]:
# Show the first microkernel registered on the generator.
# NOTE(review): index 0 is the 4x16 Neon kernel only because the previous cell
# was the first generate_exo_sgemm call on this generator. After the AVX512
# cell (clear_microkernels + generate), index 0 holds the AVX512 kernel instead,
# so re-running this cell out of order shows different output.
kernel_generator.microkernels[0]

```python
def microkernel_4x16_0(K: size, A: [f32][4, K] @ DRAM, B: [f32][K, 16] @ DRAM,
                       C: [f32][4, 16] @ DRAM):
    assert K >= 1
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1
    C_reg: R[4, 4, 4] @ Neon4f
    for i in par(0, 4):
        for jo in par(0, 4):
            neon_vld_4xf32(C_reg[i, jo, 0:4], C[i, 4 * jo:4 * jo + 4])
    for k in seq(0, K):
        for i in par(0, 4):
            A_vec: R[4] @ Neon4f
            neon_broadcast_4xf32(A_vec[0:4], A[i, k:k + 1])
            for jo in par(0, 4):
                B_vec: R[4] @ Neon4f
                neon_vld_4xf32(B_vec[0:4], B[k, 4 * jo:4 * jo + 4])
                neon_vfmadd_4xf32_4xf32(C_reg[i, jo, 0:4], A_vec[0:4],
                                        B_vec[0:4])
    for i in par(0, 4):
        for jo in par(0, 4):
            neon_vst_4xf32(C[i, 4 * jo:4 * jo + 4], C_reg[i, jo, 0:4])

```

Next, we generate a 31x31 matmul kernel with the same 4x16 microkernel. Since 31 is a multiple of neither 4 nor 16, the leftover rows and columns must be handled as edge cases.

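In addition to the main 4x16 microkernel, three edge-case microkernels are generated: `microkernel_3x16_2`, `microkernel_4x12_3`, and `microkernel_4x3x28_4` (which also blocks `K` in chunks of 28). Any remainder that none of them covers falls back to plain scalar loops: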

In [6]:
# ARM Neon, 31x31x31 problem with the same 4x16 microkernel. 31 is a multiple
# of neither 4 nor 16, so extra edge-case microkernels are generated. They are
# numbered after the kernel from the 32x32x32 cell (starting at _1) because the
# generator was not cleared in between.
edge_exo_sgemm_neon = kernel_generator.generate_exo_sgemm(
    driver.NeonMachine,
    4, 16,       # microkernel tile: 4 rows x 16 columns of C
    31, 31, 31,  # problem sizes (order presumably M, N, K — TODO confirm)
)
edge_exo_sgemm_neon

```python
def SGEMM(M: size, N: size, K: size, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM,
          C: f32[M, N] @ DRAM):
    assert M >= 1
    assert N >= 1
    assert K >= 1
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1
    for ji in par(0, N / 16):
        B_strip: f32[K, 16] @ DRAM
        for i0 in seq(0, K):
            for i1 in seq(0, 16):
                B_strip[i0, i1] = B[i0, i1 + 16 * ji]
        for ii in par(0, M / 4):
            microkernel_4x16_1(K, A[4 * ii:4 * ii + 4, 0:K],
                               B_strip[0:K, 0:16], C[4 * ii:4 * ii + 4,
                                                     16 * ji:16 * ji + 16])
        for iii in par(0, M % 4 / 3):
            microkernel_3x16_2(
                K, A[4 * (M / 4) + 3 * iii:4 * (M / 4) + 3 * iii + 3,
                     0:K], B[0:K, 16 * ji:16 * ji + 16],
                C[4 * (M / 4) + 3 * iii:4 * (M / 4) + 3 * iii + 3,
                  16 * ji:16 * ji + 16])
        for ioo in par(0, M % 4 % 3):
            for jo in par(0, 16):
                for k in par(0, K):
                    C[ioo + M % 4 / 3 * 3 + M / 4 * 4,
                      16 * ji + jo] += A[ioo + M % 4 / 3 * 3 + M / 4 * 4,
                                         k] * B[k, 16 * ji + jo]
    for ii in par(0, M / 4):
        for jii in par(0, N % 16 / 12):
            microkernel_4x12_3(
                K, A[4 * ii:4 * ii + 4, 0:K],
                B[0:K, 16 * (N / 16) + 12 * jii:16 * (N / 16) + 12 * jii + 12],
                C[4 * ii:4 * ii + 4,
                  16 * (N / 16) + 12 * jii:16 * (N / 16) + 12 * jii + 12])
        for ko in par(0, K / 28):
            for jiii in par(0, N % 16 % 12 / 3):
                microkernel_4x3x28_4(
                    A[4 * ii:4 * ii + 4, 28 * ko:28 * ko + 28],
                    B[28 * ko:28 * ko + 28,
                      12 * (N % 16 / 12) + 16 * (N / 16) + 3 * jiii:12 *
                      (N % 16 / 12) + 16 * (N / 16) + 3 * jiii + 3],
                    C[4 * ii:4 * ii + 4,
                      12 * (N % 16 / 12) + 16 * (N / 16) + 3 * jiii:12 *
                      (N % 16 / 12) + 16 * (N / 16) + 3 * jiii + 3])
            for jooo in par(0, N % 16 % 12 % 3):
                for io in par(0, 4):
                    for ki in par(0, 28):
                        C[4 * ii + io,
                          jooo + N % 16 % 12 / 3 * 3 + N % 16 / 12 * 12 +
                          N / 16 * 16] += A[4 * ii + io, 28 * ko + ki] * B[
                              28 * ko + ki, jooo + N % 16 % 12 / 3 * 3 +
                              N % 16 / 12 * 12 + N / 16 * 16]
        for jiii in par(0, N % 16 % 12 / 3):
            for jooo in par(0, 3):
                for io in par(0, 4):
                    for ki in par(0, K % 28):
                        C[4 * ii + io, 3 * jiii + jooo + N % 16 / 12 * 12 +
                          N / 16 * 16] += A[4 * ii + io, ki + K / 28 * 28] * B[
                              ki + K / 28 * 28,
                              3 * jiii + jooo + N % 16 / 12 * 12 + N / 16 * 16]
        for jooo in par(0, N % 16 % 12 % 3):
            for io in par(0, 4):
                for ki in par(0, K % 28):
                    C[4 * ii + io,
                      jooo + N % 16 % 12 / 3 * 3 + N % 16 / 12 * 12 +
                      N / 16 * 16] += A[4 * ii + io, ki + K / 28 * 28] * B[
                          ki + K / 28 * 28, jooo + N % 16 % 12 / 3 * 3 +
                          N % 16 / 12 * 12 + N / 16 * 16]
    for io in par(0, M % 4):
        for jo in par(0, N % 16):
            for k in par(0, K):
                C[io + M / 4 * 4, jo +
                  N / 16 * 16] += A[io + M / 4 * 4, k] * B[k, jo + N / 16 * 16]

```

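Now we generate a kernel that uses AVX512 intrinsics with the same method call; only the machine argument changes. The previously generated Neon microkernels are cleared first, so the new kernel is numbered `_0` again.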

In [3]:
# AVX512: the same call as the Neon cells; only the machine argument changes.
# Clearing first presumably drops the Neon microkernels, so this run's kernel
# is numbered _0 again and microkernels[0] below is the AVX512 one.
kernel_generator.clear_microkernels()
exo_sgemm_avx512 = kernel_generator.generate_exo_sgemm(
    driver.AVX512Machine,
    4, 16,       # microkernel tile: 4 rows x 16 columns of C
    32, 32, 32,  # problem sizes (order presumably M, N, K — TODO confirm)
)
kernel_generator.microkernels[0]

```python
def microkernel_4x16_0(K: size, A: [f32][4, K] @ DRAM, B: [f32][K, 16] @ DRAM,
                       C: [f32][4, 16] @ DRAM):
    assert K >= 1
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1
    C_reg: R[4, 1, 16] @ AVX512
    for i in par(0, 4):
        for jo in par(0, 1):
            mm512_loadu_ps(C_reg[i, jo, 0:16], C[i, 16 * jo:16 * jo + 16])
    for k in seq(0, K):
        for i in par(0, 4):
            A_vec: R[16] @ AVX512
            mm512_set1_ps(A_vec, A[i, k:k + 1])
            for jo in par(0, 1):
                B_vec: R[16] @ AVX512
                mm512_loadu_ps(B_vec[0:16], B[k, 16 * jo:16 * jo + 16])
                mm512_fmadd_ps(A_vec, B_vec, C_reg[i, jo, 0:16])
    for i in par(0, 4):
        for jo in par(0, 1):
            mm512_storeu_ps(C[i, 16 * jo:16 * jo + 16], C_reg[i, jo, 0:16])

```