In [1]:
from __future__ import annotations
from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *

In [2]:

def print_output(fn):
    """Print the C code generated for an Exo procedure, without its preamble.

    ``fn.c_code_str()`` returns the whole generated C source; only the text
    that follows the first ``#include <stdio.h>`` line is printed, which hides
    the common header that precedes it.

    Args:
        fn: an Exo procedure (any object exposing a ``c_code_str()`` method).

    If the marker is not present (e.g. a different Exo version emits a
    different preamble), the full source is printed instead of raising
    ``IndexError``. If the marker occurs more than once, everything after the
    first occurrence is printed rather than being silently truncated.
    """
    marker = "#include <stdio.h>"
    source = fn.c_code_str()
    _, found, body = source.partition(marker)
    print(body if found else source)

In [4]:
"""
Initial sgemm function
"""
@proc
def SGEMM(M: size, N: size, K: size, A: f32[M, K], B: f32[K, N], C: f32[M, N]):
    assert M >= 1
    assert N >= 1
    assert K >= 1
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1

    for k in par(0, K):
        for i in par(0, M):
            for j in par(0, N):
                C[i, j] += A[i, k]*B[k, j]

print_output(SGEMM)


#include <stdlib.h>


// SGEMM(
//     M : size,
//     N : size,
//     K : size,
//     A : f32[M,K]  @DRAM,
//     B : f32[K,N]  @DRAM,
//     C : f32[M,N]  @DRAM
// )
void SGEMM( c_code_str_Context *ctxt, int_fast32_t M, int_fast32_t N, int_fast32_t K, float* A, float* B, float* C ) {
EXO_ASSUME(M >= 1);
EXO_ASSUME(N >= 1);
EXO_ASSUME(K >= 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
for (int k = 0; k < K; k++) {
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      C[(i) * (N) + (j) * (1)] += A[(i) * (K) + (k) * (1)] * B[(k) * (N) + (j) * (1)];
    }
  }
}
}



In [5]:
# --- Kernel blocking constants ---

# Vector width: number of f32 lanes per vector. Used as the split factor for
# the j loop (16 x 32-bit = 512 bits, presumably one AVX-512 register).
VEC_W = 16

# Register block: each micro-kernel computes an M x N_REG_BLK tile of C,
# for M in 1..M_REG_BLK. The tile is N_REG_VECS vectors wide.
M_REG_BLK = 6
N_REG_VECS = 4  # vectors per tile row -- the "4" in the kernel names "{M}x4"
N_REG_BLK = N_REG_VECS * VEC_W

# L1 blocking sizes, expressed as multiples of the register block.
# NOTE(review): not used by any cell shown here -- presumably intended for a
# macro-kernel that tiles the full problem around the micro-kernels.
M_L1_FAC = 44
N_L1_FAC = 1

M_L1_BLK = M_REG_BLK * M_L1_FAC
N_L1_BLK = N_REG_BLK * N_L1_FAC
K_L1_BLK = 512

# Generated kernels, keyed by M (number of rows in the register tile).
basic_kernel_Mx4 = {}
sgemm_kernel_avx512_Mx4 = {}

In [6]:
# Demonstration of partial_eval() and simplify(): specialize SGEMM into one
# fixed-size kernel per row count M = 1..M_REG_BLK, with N pinned to N_REG_BLK.
basic_kernel_Mx4.update(
    (
        M,
        SGEMM
        .rename(f'basic_kernel_{M}x4')
        .partial_eval(M, N_REG_BLK)
        .simplify(),
    )
    for M in range(1, M_REG_BLK + 1)
)
print_output(basic_kernel_Mx4[1])


#include <stdlib.h>


// basic_kernel_1x4(
//     K : size,
//     A : f32[1,K]  @DRAM,
//     B : f32[K,64]  @DRAM,
//     C : f32[1,64]  @DRAM
// )
void basic_kernel_1x4( c_code_str_Context *ctxt, int_fast32_t K, float* A, float* B, float* C ) {
EXO_ASSUME(K >= 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
for (int k = 0; k < K; k++) {
  for (int i = 0; i < 1; i++) {
    for (int j = 0; j < 64; j++) {
      C[(i) * (64) + (j) * (1)] += A[(i) * (K) + (k) * (1)] * B[(k) * (64) + (j) * (1)];
    }
  }
}
}



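partial_eval(v1, v2, ...)

    Binds the procedure's leading arguments, in order, to the given constant values and removes them from the signature.
    Here `.partial_eval(M, N_REG_BLK)` fixes SGEMM's `M` to the current Python value of `M` (1 for the kernel printed above) and its `N` to N_REG_BLK (64); `K` stays a runtime argument.
    EXAMPLE
    for (i:M)
      for (j:N)
    partial_eval(1, 64)
    for (i:1)
      for (j:64)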

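simplify()

    Simplifies the procedure, dropping assertions that have become trivially true.
    In the output above, after partial_eval() + simplify() the `M >= 1` and `N >= 1` assumptions are gone (with M = 1 and N = 64 they would read `1 >= 1` and `64 >= 1`).
    Note that the three `EXO_ASSUME(1 == 1)` lines are NOT removed. They come from the `stride(X, 1) == 1` assertions, which apparently only turn into `1 == 1` during C code generation.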

In [7]:
# .split(): break each kernel's j loop into an outer jo loop and an inner
# ji loop of VEC_W iterations.
sgemm_kernel_avx512_Mx4.update(
    (
        M,
        basic_kernel_Mx4[M]
        .rename(f'sgemm_kernel_avx512_{M}x4')
        .split('j', VEC_W, ['jo', 'ji'], perfect=True),
    )
    for M in range(1, M_REG_BLK + 1)
)
print_output(sgemm_kernel_avx512_Mx4[1])


#include <stdlib.h>


// sgemm_kernel_avx512_1x4(
//     K : size,
//     A : f32[1,K]  @DRAM,
//     B : f32[K,64]  @DRAM,
//     C : f32[1,64]  @DRAM
// )
void sgemm_kernel_avx512_1x4( c_code_str_Context *ctxt, int_fast32_t K, float* A, float* B, float* C ) {
EXO_ASSUME(K >= 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
for (int k = 0; k < K; k++) {
  for (int i = 0; i < 1; i++) {
    for (int jo = 0; jo < 4; jo++) {
      for (int ji = 0; ji < 16; ji++) {
        C[(i) * (64) + (16 * jo + ji) * (1)] += A[(i) * (K) + (k) * (1)] * B[(k) * (64) + (16 * jo + ji) * (1)];
      }
    }
  }
}
}



.split(loop_var, N, split_loop_names, perfect=...)

    Splits the loop over LOOP_VAR into an outer loop with variable split_loop_names[0] and an inner loop with variable split_loop_names[1]. Every use of LOOP_VAR becomes `N * outer + inner` (here `16 * jo + ji`).
    The new loop bounds are:

    The outer loop runs to (upper bound of the original loop) / N
    The inner loop runs to N

    With perfect=True no remainder (tail) loop is generated, which presumably requires the original bound to be divisible by N (64 / 16 = 4 here).

    This is useful for writing blocking procedures, and for writing the macro-kernels that schedule calls to the GEMM micro-kernel.

# Mark k as a sequential (reduction) loop.
# Rebuild each kernel from basic_kernel_Mx4 instead of re-scheduling
# sgemm_kernel_avx512_Mx4[M] in place: the in-place form depends on the state
# left by the previous cell (KeyError if it was not run), and on a re-run it
# applies par_to_seq to an already-scheduled kernel. This version gives the
# same kernels no matter how often or in what order the cells are run.
for M in range(1, M_REG_BLK+1):
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            .par_to_seq('for k in _: _')
    )
print_output(sgemm_kernel_avx512_Mx4[1])

In [8]:
# Rebuild every kernel from scratch:
#   split j -> jo x ji, make k a sequential reduction loop,
#   then route each `C[...] += ...` through a scalar temporary C_reg.
sgemm_kernel_avx512_Mx4.update(
    (
        M,
        basic_kernel_Mx4[M]
        .rename(f'sgemm_kernel_avx512_{M}x4')
        .split('j', VEC_W, ['jo', 'ji'], perfect=True)  # vectorizable inner loop
        .par_to_seq('for k in _: _')                     # k carries the reduction
        .stage_assn('C_reg', 'C[_] += _'),               # accumulate via C_reg
    )
    for M in range(1, M_REG_BLK + 1)
)
print_output(sgemm_kernel_avx512_Mx4[1])


#include <stdlib.h>


// sgemm_kernel_avx512_1x4(
//     K : size,
//     A : f32[1,K]  @DRAM,
//     B : f32[K,64]  @DRAM,
//     C : f32[1,64]  @DRAM
// )
void sgemm_kernel_avx512_1x4( c_code_str_Context *ctxt, int_fast32_t K, float* A, float* B, float* C ) {
EXO_ASSUME(K >= 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
for (int k = 0; k < K; k++) {
  for (int i = 0; i < 1; i++) {
    for (int jo = 0; jo < 4; jo++) {
      for (int ji = 0; ji < 16; ji++) {
        float C_reg;
        C_reg = C[(i) * (64) + (16 * jo + ji) * (1)];
        C_reg += A[(i) * (K) + (k) * (1)] * B[(k) * (64) + (16 * jo + ji) * (1)];
        C[(i) * (64) + (16 * jo + ji) * (1)] = C_reg;
      }
    }
  }
}
}



.stage_assn(var, pattern)

    Rewrites the statement matching PATTERN into the following sequence:
    First, declare a scalar named VAR (a `float` here, matching the f32 data)
    Second, load the memory location updated by PATTERN into VAR
    Third, apply PATTERN's update to VAR instead of to memory
    Finally, write VAR back to the memory location loaded in step 2

    Example:
    C[i] += A[i]*B[i]
    .stage_assn(C_reg, C[_] += _)
    float C_reg;
    C_reg = C[i]
    C_reg += A[i]*B[i]
    C[i] = C_reg

In [9]:
# Same schedule as before, plus: hoist the C_reg allocation out of all four
# enclosing loops (ji, jo, i, k) so it becomes a buffer instead of a scalar.
sgemm_kernel_avx512_Mx4.update(
    (
        M,
        basic_kernel_Mx4[M]
        .rename(f'sgemm_kernel_avx512_{M}x4')
        .split('j', VEC_W, ['jo', 'ji'], perfect=True)  # vectorizable inner loop
        .par_to_seq('for k in _: _')                     # k carries the reduction
        .stage_assn('C_reg', 'C[_] += _')                # accumulate via C_reg
        .lift_alloc('C_reg: _', n_lifts=4),              # hoist C_reg out of 4 loops
    )
    for M in range(1, M_REG_BLK + 1)
)
print_output(sgemm_kernel_avx512_Mx4[1])


#include <stdlib.h>


// sgemm_kernel_avx512_1x4(
//     K : size,
//     A : f32[1,K]  @DRAM,
//     B : f32[K,64]  @DRAM,
//     C : f32[1,64]  @DRAM
// )
void sgemm_kernel_avx512_1x4( c_code_str_Context *ctxt, int_fast32_t K, float* A, float* B, float* C ) {
EXO_ASSUME(K >= 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
float *C_reg = malloc(1 * 4 * 16 * sizeof(*C_reg));
for (int k = 0; k < K; k++) {
  for (int i = 0; i < 1; i++) {
    for (int jo = 0; jo < 4; jo++) {
      for (int ji = 0; ji < 16; ji++) {
        C_reg[(i) * (4 * 16) + (jo) * (16) + (ji) * (1)] = C[(i) * (64) + (16 * jo + ji) * (1)];
        C_reg[(i) * (4 * 16) + (jo) * (16) + (ji) * (1)] += A[(i) * (K) + (k) * (1)] * B[(k) * (64) + (16 * jo + ji) * (1)];
        C[(i) * (64) + (16 * jo + ji) * (1)] = C_reg[(i) * (4 * 16) + (jo) * (16) + (ji) * (1)];
      }
    }
  }
}
free(C_reg);
}



.lift_alloc(pattern, n_lifts)

    Moves the allocation matching PATTERN outside of the N_LIFTS enclosing for loops.
    For each loop it is lifted past, the allocation gains a dimension sized by that loop's upper bound. This apparently only happens when the bound is a constant, not a variable.
    In the kernel above, lifting C_reg past the ji (16), jo (4), i (1) and k (K) loops yields a buffer of 1 * 4 * 16 floats, with no dimension for k.
    Accesses at the original location are rewritten to index the new dimensions with the matching loop variables, e.g. C_reg[(i) * (4 * 16) + (jo) * (16) + (ji) * (1)], so each (i, jo, ji) iteration gets its own element.
    This is useful for blocking: a single allocation now holds a whole tile of values instead of one scalar.

    Example:
    for (int i=0; i<16; ++i) {
        float C;
        ...
    }

    .lift_alloc('C: _', n_lifts=1)
    
    float *C = malloc(sizeof(*C)*16);
    for (int i=0; i<16; ++i) {
        ...   // uses of C become C[i]
    }

In [10]:
# Same schedule as before, plus: fission the loop nest so that loading C into
# C_reg, the k reduction, and storing C_reg back to C become separate nests.
sgemm_kernel_avx512_Mx4.update(
    (
        M,
        basic_kernel_Mx4[M]
        .rename(f'sgemm_kernel_avx512_{M}x4')
        .split('j', VEC_W, ['jo', 'ji'], perfect=True)  # vectorizable inner loop
        .par_to_seq('for k in _: _')                     # k carries the reduction
        .stage_assn('C_reg', 'C[_] += _')                # accumulate via C_reg
        .lift_alloc('C_reg: _', n_lifts=4)               # hoist C_reg out of 4 loops
        .double_fission('C_reg[_] = C[_]', 'C_reg[_] += _', n_lifts=4),
    )
    for M in range(1, M_REG_BLK + 1)
)
print_output(sgemm_kernel_avx512_Mx4[1])


#include <stdlib.h>


// sgemm_kernel_avx512_1x4(
//     K : size,
//     A : f32[1,K]  @DRAM,
//     B : f32[K,64]  @DRAM,
//     C : f32[1,64]  @DRAM
// )
void sgemm_kernel_avx512_1x4( c_code_str_Context *ctxt, int_fast32_t K, float* A, float* B, float* C ) {
EXO_ASSUME(K >= 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
EXO_ASSUME(1 == 1);
float *C_reg = malloc(1 * 4 * 16 * sizeof(*C_reg));
for (int i = 0; i < 1; i++) {
  for (int jo = 0; jo < 4; jo++) {
    for (int ji = 0; ji < 16; ji++) {
      C_reg[(i) * (4 * 16) + (jo) * (16) + (ji) * (1)] = C[(i) * (64) + (16 * jo + ji) * (1)];
    }
  }
}
for (int k = 0; k < K; k++) {
  for (int i = 0; i < 1; i++) {
    for (int jo = 0; jo < 4; jo++) {
      for (int ji = 0; ji < 16; ji++) {
        C_reg[(i) * (4 * 16) + (jo) * (16) + (ji) * (1)] += A[(i) * (K) + (k) * (1)] * B[(k) * (64) + (16 * jo + ji) * (1)];
      }
    }
  }
}
for (int i = 0; i < 1; i++) {
  for (int jo = 0; jo < 4; jo++) {
    for (int ji = 0; ji < 16; ji++) {
      C[

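.double_fission(pattern1, pattern2, n_lifts)

    Splits (fissions) the N_LIFTS loops enclosing PATTERN1 and PATTERN2 at two points: just after the statement matching PATTERN1, and just after the statement matching PATTERN2. Each resulting piece gets its own copy of the enclosing loops.

    Here this turns the single loop nest into three nests. The first loads C into C_reg, the second performs the k reduction into C_reg, and the third stores C_reg back to C. In the output, the load and store nests no longer contain the k loop, presumably because their bodies do not depend on k.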


In [12]:
# Extend the previous schedule by staging the A and B operands:
# stage_expr presumably binds each expression matching 'A[_, _]' / 'B[_, _]'
# to a new variable (A_vec / B_vec) allocated in AVX512 memory -- TODO confirm
# against the Exo stage_expr documentation.
# FIXME(review): the recorded run of this cell failed with
#   TypeError: issubclass() arg 1 must be a class
# The traceback was not kept, so it is unclear whether the error is raised by
# stage_expr(..., memory=AVX512) itself or later by C code generation inside
# print_output. Investigate before relying on this cell; a finished notebook
# should not end on an uncaught exception.
for M in range(1, M_REG_BLK+1):
    sgemm_kernel_avx512_Mx4[M] = (
        basic_kernel_Mx4[M]
            .rename(f'sgemm_kernel_avx512_{M}x4')
            .split('j', VEC_W, ['jo', 'ji'], perfect=True)
            # Mark k as a reduction loop
            .par_to_seq('for k in _: _')
            # Stage C for reduction
            .stage_assn('C_reg', 'C[_] += _')
            .lift_alloc('C_reg: _', n_lifts=4)
            .double_fission('C_reg[_] = C[_]', 'C_reg[_] += _', n_lifts=4)
            # Stage the A and B reads into AVX512-memory temporaries
            .stage_expr('A_vec', 'A[_, _]', memory=AVX512)
            .stage_expr('B_vec', 'B[_, _]', memory=AVX512)
    )
print_output(sgemm_kernel_avx512_Mx4[1])

TypeError: issubclass() arg 1 must be a class