In [None]:
from __future__ import annotations
from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.platforms.neon import *
from exo.syntax import *

In [None]:
# Reference single-precision GEMM written as an Exo procedure.
# Computes C[i, j] += sum_k A[i, k] * B[k, j]: C is accumulated into, not
# overwritten, so a caller that wants C = A @ B must zero C beforehand.
# The loop names `i` and `j` are referenced by name in the scheduling cell that
# builds GEBP (split('i', ...), split('j', ...)); renaming them breaks that schedule.
@proc
def SGEMM(
    M: size,  # first dimension of A and C
    N: size,  # second dimension of B and C
    K: size,  # reduction dimension: second dimension of A, first dimension of B
    A: f32[M, K],
    B: f32[K, N],
    C: f32[M, N]  # accumulated in place (see `+=` below)
):
    # Every dimension must be non-empty.
    assert M >= 1
    assert N >= 1
    assert K >= 1
    # The last dimension of each matrix must be contiguous (unit stride);
    # no constraint is asserted on the stride of dimension 0.
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1

    # Naive triple loop with the reduction index k innermost.
    # NOTE(review): `par` is the loop form this Exo version accepts; newer Exo
    # releases write sequential loops as `seq(0, N)` — confirm the pinned version.
    for i in par(0, M):
        for j in par(0, N):
            for k in par(0, K):
                C[i, j] += A[i, k] * B[k, j]
# Show the C code Exo generates for the unscheduled reference kernel.
print(SGEMM.c_code_str())

In [None]:
# GEBP ("GEneral Block-Panel") kernel: multiply an M_c x K_c block of A by a
# K_c x N panel of B, handing each register-sized tile to `microkernel`.
#
# NOTE(review): `sgemm_win`, `M_r`, `N_r`, `M_c` and `K_c` are not defined in any
# visible cell, so on Restart & Run All this cell raises NameError. Define them in
# an earlier (config) cell — TODO.

# Specialise the windowed SGEMM kernel to a fixed tile size.
# NOTE(review): the two sizes are bound positionally, so their meaning depends on
# the order of sgemm_win's size parameters — confirm against its signature.
microkernel = sgemm_win.rename('microkernel')
microkernel = microkernel.partial_eval(N_r, M_r)
microkernel = microkernel.simplify()

# Start from the reference SGEMM and fix the block sizes: M -> M_c, K -> K_c.
GEBP = SGEMM.rename("GEBP")
GEBP = GEBP.partial_eval(M=M_c)
GEBP = GEBP.partial_eval(K=K_c)

# Tile both output loops. `tail='cut_and_guard'` deals with iterations that do
# not fill a whole tile (see the Exo docs for the exact loop shape produced).
# NOTE(review): loop `i` (extent M) is tiled by N_r and loop `j` (extent N) by
# M_r — the reverse of the usual M_r x N_r naming. This is only consistent if it
# matches the positional partial_eval(N_r, M_r) of the microkernel; confirm.
GEBP = GEBP.split('i', N_r, ['io', 'ii'], tail='cut_and_guard')
GEBP = GEBP.split('j', M_r, ['jo', 'ji'], tail='cut_and_guard')

# Edge case: fission after the `jo` loop, lifting through two enclosing levels —
# presumably to separate the full-tile loop nest from the leftover tail iterations.
GEBP = GEBP.fission_after('for jo in _: _', n_lifts=2)

# Swap `ii` and `jo` — presumably so that each (io, jo) iteration contains an
# ii/ji/k nest that replace_all can match against the microkernel body.
GEBP = GEBP.reorder('ii', 'jo')

# NOTE: `.unroll('io')` is intentionally not applied; per the original author,
# loops that lack a constant bound cannot be unrolled.

# Replace every matching tile nest with a call to `microkernel`, then tidy up.
GEBP = GEBP.replace_all(microkernel)
GEBP = GEBP.simplify()

print(GEBP.c_code_str())