## Step-by-step GEMM implementation in EXO. With some very rough documentation. ##
##### Author: Julian Bellavita, UC Berkeley #####

In [None]:
from __future__ import annotations
from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *

In [None]:

def print_output(fn, signature=None):
    """Print one kernel's C signature followed by part of its generated code.

    The C source returned by ``fn.c_code_str()`` is split on ``signature``
    and the third piece is printed, i.e. the text that follows the second
    occurrence of the signature (up to a third occurrence, if there is one).

    Args:
        fn: Any object exposing ``c_code_str()`` that returns C source text.
        signature: Exact C signature text to split on. When omitted, the
            signature of ``sgemm_kernel_avx512_1x4`` is used, which is what
            this helper was originally hard-coded to.

    Raises:
        IndexError: If ``signature`` occurs fewer than two times in the
            generated source. The signature line has already been printed
            by then.
    """
    if signature is None:
        signature = (
            "void sgemm_kernel_avx512_1x4( c_code_str_Context *ctxt, "
            "int_fast32_t K, float* A, float* B, float* C )"
        )
    out = fn.c_code_str()
    print(signature)
    # Index 2 selects the text after the second occurrence of the signature.
    print(out.split(signature)[2])

In [None]:
"""
Initial sgemm function
"""
# Reference single-precision GEMM as an EXO procedure: accumulates A * B into
# C, where A is M x K, B is K x N and C is M x N. Every schedule in this
# notebook is derived from this procedure.
@proc
def SGEMM(M: size, N: size, K: size, A: f32[M, K], B: f32[K, N], C: f32[M, N]):
    # All three dimensions must be non-empty.
    assert M >= 1
    assert N >= 1
    assert K >= 1
    # Dimension 1 (the column index) of each matrix must have unit stride.
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1

    # k is the outermost loop; it is the reduction dimension and is turned
    # into a sequential loop later with par_to_seq.
    for k in par(0, K):
        for i in par(0, M):
            for j in par(0, N):
                C[i, j] += A[i, k]*B[k, j]

#print_output(SGEMM)

In [None]:
# --- Kernel constants ---
# Vector width used by the schedules below (16 f32 values per AVX-512 vector).
VEC_W = 16

# Register-level tile: M_REG_BLK rows by N_REG_BLK columns (four vectors wide).
M_REG_BLK = 6
N_REG_BLK = VEC_W * 4

# Number of register tiles per L1 block along M and N.
M_L1_FAC, N_L1_FAC = 44, 1

# L1-level block sizes along M, N and K.
M_L1_BLK = M_L1_FAC * M_REG_BLK
N_L1_BLK = N_L1_FAC * N_REG_BLK
K_L1_BLK = 512

# Generated procedures, keyed by the kernel's row count M.
basic_kernel_Mx4 = dict()
sgemm_kernel_avx512_Mx4 = dict()

In [None]:
# Demonstrates partial_eval() and simplify(): for every row count M in
# 1..M_REG_BLK, derive a copy of SGEMM whose first two size arguments are
# fixed to M and N_REG_BLK.
for M in range(1, M_REG_BLK + 1):
    basic_kernel_Mx4[M] = SGEMM.rename(f'basic_kernel_{M}x4').partial_eval(
        M, N_REG_BLK
    ).simplify()

# Show the generated C for the 4-row variant.
print(basic_kernel_Mx4[4].c_code_str())

partial_eval(M=N)

    replaces the upper bounds identified by M with N

simplify()

    removes all statements that always evaluate to true, such as assert(1==1)

In [None]:
# Demonstrates .split(): start each AVX-512 kernel from its basic kernel and
# break the j loop into an outer jo loop and an inner ji loop of VEC_W steps.
for M in range(1, M_REG_BLK + 1):
    sgemm_kernel_avx512_Mx4[M] = basic_kernel_Mx4[M].rename(
        f'sgemm_kernel_avx512_{M}x4'
    ).split('j', VEC_W, ['jo', 'ji'], perfect=True)

# Show the generated C for the 1-row kernel.
print(sgemm_kernel_avx512_Mx4[1].c_code_str())

.split(loop_var, N, split_loop_names)

    splits the loop with variable LOOP_VAR into an outer loop with variable split_loop_names[0] and an inner loop with variable split_loop_names[1].
    The upper bound of the new loops is determined in the following way:

    The outer loop upper bound is the upper bound of the original loop / N
    The inner loop upper bound is N

    This could be very useful when writing blocking procedures, as well as when writing the macrokernels that schedule the calls to the GEMM microkernel 

# Convert the k loop of every kernel built so far with par_to_seq.
for M in range(1, M_REG_BLK + 1):
    sgemm_kernel_avx512_Mx4[M] = sgemm_kernel_avx512_Mx4[M].par_to_seq('for k in _: _')

# Print only the 1-row kernel.
print_output(sgemm_kernel_avx512_Mx4[1])

In [None]:
# Tutorial step: rebuild every AVX-512 kernel from its basic kernel, repeating
# the earlier split/par_to_seq steps and adding stage_assn.
for M in range(1, M_REG_BLK+1):
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            # Split j into jo (outer) and ji (inner, VEC_W iterations)
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            # Mark k as a reduction loop
            .par_to_seq('for k in _: _')
            # Stage C for reduction
            .stage_assn('C_reg', 'C[_] += _')
    )
# Print only the 1-row kernel.
print_output(sgemm_kernel_avx512_Mx4[1])

.stage_assn(var, pattern)

    Turns the line that matches PATTERN into the following block of code.
    First, create a float named VAR
    Second, read the memory originally read by PATTERN into VAR
    Third, perform the write operation originally performed in PATTERN on VAR
    Finally, write VAR to the memory location read in step 2

    Example:
    C[i] += A[i]*B[i]
    .stage_assn(C_reg, C[_] += _)
    C_reg = C[i]
    C_reg += A[i]*B[i]
    C[i] = C_reg
    

In [None]:
# Tutorial step: same chain as before, now also lifting the C_reg allocation.
for M in range(1, M_REG_BLK+1):
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            # Split j into jo (outer) and ji (inner, VEC_W iterations)
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            # Mark k as a reduction loop
            .par_to_seq('for k in _: _')
            # Stage C for reduction
            .stage_assn('C_reg', 'C[_] += _')
            # Hoist the C_reg allocation out of four enclosing loops
            # (the nest at this point is k, i, jo, ji)
            .lift_alloc('C_reg: _', n_lifts=4)
    )
# Print only the 1-row kernel.
print_output(sgemm_kernel_avx512_Mx4[1])

.lift_alloc(pattern, n_lifts)

    moves the memory allocation matching PATTERN outside of N_LIFTS for loops
    For each loop that PATTERN is lifted out of, the amount of memory allocated to it is multiplied by the upper bound of the loop. This appears to only happen if the upper bound is a static integer, i.e. it won't happen if the upper bound is a variable
    Also changes the way in which PATTERN is accessed in its original location. The precise way it does so is a bit confusing, so I think it is easier to think of it in this way:
    It lets you load contiguous sections of memory that are as large as the upper bound of the first for loop PATTERN is lifted out of. The number of contiguous sections you load is determined by the upper bounds of the other loops you lift PATTERN out of. This is useful for something like blocking, because you can assign a bunch of contiguous blocks.

    Example:
    for (int i=0; i<16; ++i) {
        float C;
        ...
    }

    .lift_alloc('C: _', n_lifts=1)
    
    float *C = malloc(sizeof(*C)*16)
    for (int i=0; i<16; ++i) {
        ...
    }

In [None]:
# Tutorial step: same chain as before, now adding double_fission.
for M in range(1, M_REG_BLK+1):
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            # Split j into jo (outer) and ji (inner, VEC_W iterations)
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            # Mark k as a reduction loop
            .par_to_seq('for k in _: _')
            # Stage C for reduction
            .stage_assn('C_reg', 'C[_] += _')
            # Hoist the C_reg allocation out of four enclosing loops
            .lift_alloc('C_reg: _', n_lifts=4)
            # Fission around the C_reg load and the accumulation, through four
            # loops. NOTE(review): presumably this leaves separate loop nests
            # for loading C, accumulating, and storing back - confirm against
            # the printed output.
            .double_fission('C_reg[_] = C[_]', 'C_reg[_] += _', n_lifts=4)
    )
# Print only the 1-row kernel.
print_output(sgemm_kernel_avx512_Mx4[1])

.double_fission(pattern1, pattern2, n_lifts)

    Lifts PATTERN1 n_lifts loops above its current location and also copies the n_lifts-1 loops above PATTERN1, wrapping PATTERN1 in those copied loops.

    Does the same thing to PATTERN2, but it moves it below its original location instead of above. 


In [None]:
# Tutorial step: same chain as before, now placing C_reg in DRAM_STATIC and
# staging the A and B operands into their own buffers.
for M in range(1, M_REG_BLK+1):
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            # Split j into jo (outer) and ji (inner, VEC_W iterations)
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            # Mark k as a reduction loop
            .par_to_seq('for k in _: _')
            # Stage C for reduction
            .stage_assn('C_reg', 'C[_] += _')
            # Allocate C_reg in DRAM_STATIC memory
            .set_memory('C_reg', DRAM_STATIC)
            .lift_alloc('C_reg: _', n_lifts=4)
            .double_fission('C_reg[_] = C[_]', 'C_reg[_] += _', n_lifts=4)
            # Stage the A and B reads into A_vec / B_vec (also DRAM_STATIC)
            .stage_expr('A_vec', 'A[_, _]', memory=DRAM_STATIC)
            .stage_expr('B_vec', 'B[_, _]', memory=DRAM_STATIC)
    )
# Print the complete generated C for the 1-row kernel.
print(sgemm_kernel_avx512_Mx4[1].c_code_str())

replace_all(expr)

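    It is not obvious how it picks what to replace. My best guess is that it searches the procedure for every block of code that matches the body of EXPR (an instruction or another procedure) and replaces each match with a call to EXPR.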

replace(expr, pattern)

    Replaces the pattern with the expression.

In [None]:
# Make each of the microkernels. Each one multiplies an M_r*K_c strip of A
# with a K_c*N_r strip of B.
# Appears to mirror the style of Goto and van de Geijn.
for M in range(1, M_REG_BLK+1):
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            # Split j into jo (outer) and ji (inner, VEC_W iterations)
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            # Mark k as a reduction loop
            .par_to_seq('for k in _: _')
            # Stage C for reduction
            .stage_assn('C_reg', 'C[_] += _')
            # Keep the accumulator in AVX512 memory this time
            .set_memory('C_reg', AVX512)
            .lift_alloc('C_reg: _', n_lifts=4)
            .double_fission('C_reg[_] = C[_]', 'C_reg[_] += _', n_lifts=4)
            # Stage the A and B reads into AVX512 buffers
            .stage_expr('A_vec', 'A[_, _]', memory=AVX512)
            .stage_expr('B_vec', 'B[_, _]', memory=AVX512)
            # Schedule ops
            # ji loop #0 becomes a vector load, ji loop #3 a vector store
            .replace(mm512_loadu_ps, 'for ji in _: _ #0')
            .replace(mm512_storeu_ps, 'for ji in _: _ #3')
            # Replace the remaining matching code with set1 / load / fmadd
            .replace_all(mm512_set1_ps)
            .replace_all(mm512_loadu_ps)
            .replace_all(mm512_fmadd_ps)
            # LICM
            .lift_alloc('A_vec: _')
            .fission_after('mm512_set1_ps(_)')
            # Clean up
            .simplify()
    )
# Print only the 1-row kernel.
print_output(sgemm_kernel_avx512_Mx4[1])

lift_alloc(pattern)

    moves the memory allocation statement matching PATTERN one loop higher

fission_after(pattern)

    moves the generic statement matching PATTERN one loop higher and matches it with an allocation statement.

In [None]:
# This multiplies an M_r*K_c panel of A by an N_r*K_c panel of B.
# When the microkernels are added in the cell below this one, it calls the
# appropriate microkernel to multiply the M_r*K_c strip of A by the K_c*N_r
# strip of B. So it's like a long, thin horizontal panel of A multiplied by a
# long, vertical panel of B.
bottom_panel_kernel = (
    SGEMM
        .rename('bottom_panel_kernel')
        # Fix the column count to one register block
        .partial_eval(N=N_REG_BLK)
        # Only row counts below a full register block are handled here
        .add_assertion(f'M < {M_REG_BLK}')
        .simplify()
)
print(bottom_panel_kernel.c_code_str())

In [None]:
# Copy of SGEMM whose A, B and C arguments are marked as windows. The kernels
# below are rebuilt from it.
SGEMM_WINDOW = (SGEMM.rename('SGEMM_WINDOW')
                .set_window('A', True)
                .set_window('B', True)
                .set_window('C', True))

# Constants for scheduling
# (same values as the kernel-constants cell near the top of the notebook)
VEC_W = 16

M_REG_BLK = 6
N_REG_BLK = (4 * VEC_W)

M_L1_FAC = 44
N_L1_FAC = 1

M_L1_BLK = M_REG_BLK * M_L1_FAC
N_L1_BLK = N_REG_BLK * N_L1_FAC
K_L1_BLK = 512

# Rebuild both kernel tables from scratch, this time on top of SGEMM_WINDOW.
basic_kernel_Mx4 = {}
sgemm_kernel_avx512_Mx4 = {}
for M in range(1, M_REG_BLK + 1):
    # Fixed-size (M x N_REG_BLK) unscheduled kernel
    basic_kernel_Mx4[M] = (
        SGEMM_WINDOW
            .rename(f'basic_kernel_{M}x4')
            .partial_eval(M, N_REG_BLK)
            .simplify()
    )
    # Scheduled AVX-512 version of the same kernel
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            # Vectorize columns
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            # Mark k as a reduction loop
            .par_to_seq('for k in _: _')
            # Stage C for reduction
            .stage_assn('C_reg', 'C[_] += _')
            .set_memory('C_reg', AVX512)
            .lift_alloc('C_reg: _', n_lifts=4)
            .double_fission('C_reg[_] = C[_]', 'C_reg[_] += _', n_lifts=4)
            # Stage A & B
            .stage_expr('A_vec', 'A[_, _]', memory=AVX512)
            .stage_expr('B_vec', 'B[_, _]', memory=AVX512)
            # Schedule ops
            .replace(mm512_loadu_ps, 'for ji in _: _ #0')
            .replace(mm512_storeu_ps, 'for ji in _: _ #3')
            .replace_all(mm512_set1_ps)
            .replace_all(mm512_loadu_ps)
            .replace_all(mm512_fmadd_ps)
            # LICM
            .lift_alloc('A_vec: _')
            .fission_after('mm512_set1_ps(_)')
            # Clean up
            .simplify()
    )

# Bottom panel: N fixed to one register block, M asserted to be smaller than
# a full register block.
bottom_panel_kernel = (
    SGEMM_WINDOW
        .rename('bottom_panel_kernel')
        .partial_eval(N=N_REG_BLK)
        .add_assertion(f'M < {M_REG_BLK}')
        .simplify()
)

# Scheduled bottom panel: one branch per possible M (1..M_REG_BLK-1), each
# ending up as a call to the matching AVX-512 kernel.
bottom_panel_kernel_scheduled = (
    bottom_panel_kernel
        .rename('bottom_panel_kernel_scheduled')
        # Specialize branches (simplify needed to unify with basic kernels)
        .specialize('for k in _: _ #0',
                    [f'M == {i}' for i in range(1, M_REG_BLK)])
        .simplify()
        # Replace each specialized branch with a call to its basic kernel
        .replace_all(basic_kernel_Mx4[1])
        .replace_all(basic_kernel_Mx4[2])
        .replace_all(basic_kernel_Mx4[3])
        .replace_all(basic_kernel_Mx4[4])
        .replace_all(basic_kernel_Mx4[5])
        # Swap each basic-kernel call for the AVX-512 kernel derived from it
        .call_eqv(sgemm_kernel_avx512_Mx4[1], 'basic_kernel_1x4(_)')
        .call_eqv(sgemm_kernel_avx512_Mx4[2], 'basic_kernel_2x4(_)')
        .call_eqv(sgemm_kernel_avx512_Mx4[3], 'basic_kernel_3x4(_)')
        .call_eqv(sgemm_kernel_avx512_Mx4[4], 'basic_kernel_4x4(_)')
        .call_eqv(sgemm_kernel_avx512_Mx4[5], 'basic_kernel_5x4(_)')
        # Clean up
        .simplify()
)
print(bottom_panel_kernel_scheduled.c_code_str())

specialize(pattern, [condition_lst])

    creates special branches for each of the conditions in CONDITION_LST. Inserts above the statement that matches PATTERN. 

call_eqv(fn, pattern)

    replaces function calls that match PATTERN with calls to FN


In [None]:
# Now for the right kernel.
# This one is basically just the microkernel (M_r*K_c and K_c*N_r), but it
# handles cases where M_r is the largest it can be (6 in this case).
# There is a case for each possible value of N_r/VEC_W, or in other words,
# for how many vectors can fit into the register block used for N.

right_panel_kernel = (
    SGEMM_WINDOW
        .rename('right_panel_kernel')
        # Fix the row count to a full register block
        .partial_eval(M=M_REG_BLK)
        # Only widths of fewer than four full vectors are handled here
        .add_assertion(f'N / {VEC_W} < 4')
        .simplify()
)
print(right_panel_kernel.c_code_str())


In [None]:
# Vectorized right panel kernel. N is not a multiple of VEC_W here, so the
# schedule uses plain AVX-512 instructions and masked ones.
# NOTE(review): the '#N' suffixes select the N-th match at that point in the
# schedule, so the steps below must stay in exactly this order.
right_panel_kernel_opt = (
    right_panel_kernel
        .rename('right_panel_kernel_opt')
        # Stage the C accumulation and split j by VEC_W with a separate tail
        .stage_assn('C_reg', 'C[_] += _')
        .split('j', VEC_W, ['jo', 'ji'], tail='cut')
        # Match #1 of the ji loops is presumably the tail loop - confirm
        .bound_and_guard('for ji in _: _ #1')
        .fission_after('for jo in _: _', n_lifts=2)
        # Mark k as a reduction loop
        .par_to_seq('for k in _: _')
        # Hoist the C_reg allocations
        .lift_alloc('C_reg: _', n_lifts=4)
        .reorder_before('C_reg: _ #1')
        # Fission after the C_reg load and after the accumulation
        .fission_after('C_reg[_] = _', n_lifts=4)
        .fission_after('C_reg[_] += _', n_lifts=4)
        # Reorder the resulting i loops
        .reorder_before('for i in _: _ #3')
        .reorder_before('for i in _: _ #2')
        # Reorder the second k loop
        .reorder_before('for k in _: _ #1')
        # Keep the accumulator in AVX512 memory
        .set_memory('C_reg', AVX512)
        # Stage the A and B reads into AVX512 buffers
        .stage_expr('A_reg', 'A[_]', memory=AVX512)
        .stage_expr('B_reg', 'B[_]', memory=AVX512)
        # Unmasked instructions: broadcast, fmadd, loads and store
        .replace_all(mm512_set1_ps)
        .replace_all(mm512_fmadd_ps)
        .replace(mm512_loadu_ps, 'for ji in _: _ #0')
        .replace(mm512_loadu_ps, 'for ji in _: _ #1')
        .replace(mm512_storeu_ps, 'for ji in _: _ #2')
        # Masked load/store for ji loops #0 and #1 among those still present
        .replace(mm512_maskz_loadu_ps, 'for ji in _: _ #0')
        .replace(mm512_mask_storeu_ps, 'for ji in _: _ #1')
        # Stage the second A and B reads (match #1) as well
        .stage_expr('A_reg2', 'A[_] #1', memory=AVX512, n_lifts=2)
        .stage_expr('B_reg2', 'B[_] #1', memory=AVX512, n_lifts=2)
        # Masked broadcast, fmadd and load
        .replace_all(mm512_mask_set1_ps)
        .replace_all(mm512_mask_fmadd_ps)
        .replace_all(mm512_maskz_loadu_ps)
        # Fuse pairs of adjacent i and k loops back together
        .fuse_loop('for i in _: _ #0', 'for i in _: _ #1')
        .fuse_loop('for k in _: _ #0', 'for k in _: _ #1')
        .fuse_loop('for i in _: _ #1', 'for i in _: _ #2')
        .fuse_loop('for i in _: _ #2', 'for i in _: _ #3')
        # Clean up
        .simplify()
)
print(right_panel_kernel_opt.c_code_str())

fission_after(pattern, n_lifts)

    copies the code below the loop identified by PATTERN, moves it N_LIFTS
    loops outside of its original location, and places it at the bottom of the last
    loop it was moved out of
    

bound_and_guard(pattern)

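    Finds the loop identified by PATTERN. My understanding (to be confirmed against the EXO documentation) is that it replaces the loop's non-constant upper bound with a constant bound and adds an if-guard inside the loop body, so that only the original iterations execute.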

reorder_before(pattern)

    Slightly confused by this one. My best guess is that it creates a copy of PATTERN and replaces the identified numerical reference to PATTERN (the #N thing) with the copy of PATTERN. This is what it does when PATTERN is a memory allocation statement.

    When PATTERN is a loop, it matches the #Nth instance of the loop, and moves it one loop above its current position.

stage_expr(name, pattern, memory, n_lifts)

    Matches the #Nth occurrence of PATTERN and creates an expression that loads it into NAME, which is located in MEMORY (like AVX512 for example).
    N_LIFTS determines how many loops the memory allocation and the expression are lifted up out of, similar to other applications of N_LIFTS 

fuse_loop(loop1, loop2)

    Fuses the body of loop1 with loop2, so now both bodies execute inside of a single loop.

In [None]:
# This handles the possible N_r/VEC_W values and creates inline calls to the
# right_panel_kernel for each case.
right_panel_kernel_scheduled = (
    right_panel_kernel
        .rename('right_panel_kernel_scheduled')
        # Replace the matching code with a call to right_panel_kernel
        .replace_all(right_panel_kernel)
        # One branch per value of N / VEC_W in 0..(N_REG_BLK // VEC_W - 1)
        .specialize('right_panel_kernel(_) #0',
                    [f'(N / {VEC_W}) == {i}' for i in range(N_REG_BLK // VEC_W)])
        # Point every call at the optimized kernel, then inline those calls
        .repeat(Procedure.call_eqv, right_panel_kernel_opt,
                'right_panel_kernel(_)')
        .repeat(Procedure.inline, 'right_panel_kernel_opt(_)')
        # Clean up
        .simplify()
        # Inline the window statements for A, B and C
        .repeat(Procedure.inline_window, 'A = _')
        .repeat(Procedure.inline_window, 'B = _')
        .repeat(Procedure.inline_window, 'C = _')
        # Clean up
        .simplify()
)
print(right_panel_kernel_scheduled.c_code_str())

repeat(Procedure.fn, arg, pattern)

    Applies FN to the procedure with the remaining arguments (here ARG and PATTERN), and keeps re-applying it to the result until it no longer succeeds. The effect is that FN is applied to every match of PATTERN, one match per application.

In [None]:
# Calls either the microkernel, bottom_panel, or right_panel depending on
# whether or not the register sizes evenly divide the block sizes.
# NOTE(review): the '#N' suffixes select the N-th match at that point in the
# schedule, so the steps below must stay in exactly this order.
sgemm_above_kernel = (
    SGEMM_WINDOW
        .rename('sgemm_above_kernel')
        # Split up into cases
        .split('j', N_REG_BLK, ['jo', 'ji'], tail='cut_and_guard')
        .split('i', M_REG_BLK, ['io', 'ii'], tail='cut_and_guard')
        .fission_after('for jo in _: _ #0', n_lifts=2)
        .reorder('ii #0', 'jo')
        .fission_after('for io in _: _')
        .reorder('k #0', 'io')
        .reorder('k #0', 'jo')
        .lift_if('if N % _ > 0: _ #0', n_lifts=3)
        .reorder('k', 'io')
        .lift_if('if M % _ > 0: _ #0')
        .fission_after('for jo in _: _ #1', n_lifts=2)
        .reorder('ii', 'jo')
        .reorder('k', 'jo')
        .lift_if('if N % _ > 0: _ #1', n_lifts=2)
        # Main block
        # (full M_REG_BLK x N_REG_BLK tiles use the 6-row AVX-512 kernel)
        .replace_all(basic_kernel_Mx4[6])
        .call_eqv(sgemm_kernel_avx512_Mx4[6], 'basic_kernel_6x4(_)')
        # Right panel
        .replace_all(right_panel_kernel)
        .call_eqv(right_panel_kernel_scheduled, 'right_panel_kernel(_)')
        # Bottom panel
        .replace_all(bottom_panel_kernel)
        .call_eqv(bottom_panel_kernel_scheduled, 'bottom_panel_kernel(_)')
        # TODO: bottom-right tile
        .simplify()
)
print(sgemm_above_kernel.c_code_str())

In [None]:
# Handles blocking and all possible edge cases involving block size
# divisibility with the complete dimensions.
# Outline: split k, i and j by the L1 block sizes, separate the eight
# full/partial-block cases, replace each case's ki loop with a call to
# SGEMM_WINDOW, stage windows of A and B into DRAM_STATIC caches, and finally
# swap every SGEMM_WINDOW call for sgemm_above_kernel.
# NOTE(review): the '#N' suffixes select the N-th match at that point in the
# schedule, so the steps below must stay in exactly this order.
sgemm_exo = (
    SGEMM
        .rename('sgemm_exo')
        # Split all loops
        .split('k', K_L1_BLK, ['ko', 'ki'], tail='cut_and_guard')
        .split('i', M_L1_BLK, ['io', 'ii'], tail='cut_and_guard')
        .split('j', N_L1_BLK, ['jo', 'ji'], tail='cut_and_guard')
        # Explode into 8 cases
        .fission_after('for io in _: _', n_lifts=2)
        .fission_after('for jo in _: _', n_lifts=4)
        # Case 1:
        .reorder('ki', 'io')
        .reorder('ii', 'jo')
        .reorder('ki', 'jo')
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        # Case 2:
        .lift_if('if N % _ > 0: _ #0', n_lifts=4)
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        # Case 3:
        .lift_if('if M % _ > 0: _ #0', n_lifts=2)
        .reorder('ki', 'jo')
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        # Case 4:
        .lift_if('if M % _ > 0: _ #1', n_lifts=2)
        .lift_if('if N % _ > 0: _ #1', n_lifts=3)
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        # Case 5:
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        # Case 6:
        .lift_if('if N % _ > 0: _ #2', n_lifts=3)
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        # Case 7:
        .lift_if('if M % _ > 0: _ #2')
        .reorder('ki', 'jo')
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        # Case 8:
        .lift_if('if M % _ > 0: _ #3')
        .lift_if('if N % _ > 0: _ #3', n_lifts=2)
        .replace(SGEMM_WINDOW, 'for ki in _: _ #0')
        ## Case 1 memory staging
        # (the only case that stages A as well as B)
        .stage_window('A1_cache', 'A[_] #0', DRAM_STATIC)
        .stage_window('B1_cache', 'B[_] #0', DRAM_STATIC)
        .par_to_seq('for ko in _: _ #0')
        .par_to_seq('for io in _: _ #0')
        .par_to_seq('for jo in _: _ #0')
        .lift_alloc('A1_cache: _', n_lifts=3)
        .lift_alloc('B1_cache: _', n_lifts=3)
        .fission_after('for i0 in _: _ #0')
        ## Case 2 memory staging
        # (second dimension of the cache is bounded by N_L1_BLK)
        .stage_window('B2_cache', 'B[_] #1', DRAM_STATIC)
        .bound_alloc('B2_cache: _', [None, f'{N_L1_BLK}'])
        .lift_alloc('B2_cache: _')
        .fission_after('for i0 in _: _ #2')
        ## Case 3 memory staging
        .stage_window('B3_cache', 'B[_] #2', DRAM_STATIC)
        ## Case 4 memory staging
        .stage_window('B4_cache', 'B[_] #3', DRAM_STATIC)
        .bound_alloc('B4_cache: _', [None, f'{N_L1_BLK}'])
        ## Case 5 memory staging
        # (first dimension of the cache is bounded by K_L1_BLK)
        .stage_window('B5_cache', 'B[_] #4', DRAM_STATIC)
        .bound_alloc('B5_cache: _', [f'{K_L1_BLK}', None])
        ## Case 6 memory staging
        # (both dimensions of the cache are bounded)
        .stage_window('B6_cache', 'B[_] #5', DRAM_STATIC)
        .bound_alloc('B6_cache: _', [f'{K_L1_BLK}', f'{N_L1_BLK}'])
        # The next two steps are disabled for this case.
        # .lift_alloc('B6_cache: _')
        # .fission_after('for i0 in _: _ #6')
        ## Case 7 memory staging
        .stage_window('B7_cache', 'B[_] #6', DRAM_STATIC)
        .bound_alloc('B7_cache: _', [f'{K_L1_BLK}', None])
        ## Case 8 memory staging
        .stage_window('B8_cache', 'B[_] #7', DRAM_STATIC)
        .bound_alloc('B8_cache: _', [f'{K_L1_BLK}', f'{N_L1_BLK}'])
        ## Replace SGEMM_WINDOW with optimized form
        # These must come AFTER bound_alloc since the internal check-effects
        # is a whole program analysis that is VERY expensive
        .repeat(Procedure.call_eqv, sgemm_above_kernel, 'SGEMM_WINDOW(_)')
        # Clean up
        .simplify()
)
#print(sgemm_exo.c_code_str())

reorder(var1, var2)

    swaps the positions of the #Nth occurrence of VAR1 with VAR2

lift_if(pattern, n_lifts)

    moves the if statement matching PATTERN N_LIFTS brackets above its current location