## 🧮 Performance scaling for GEMM 

In this example, we demonstrate how to leverage **L2 buffer reuse** and enhance **parallelism** in a design by applying ARIES primitives to the GEMM operation.

In [None]:
import os
import sys
# Remember the notebook's working directory; a later cell derives the
# project directory from it.
cur_dir = os.getcwd()
# Directory four levels above the working directory (presumably the ARIES
# repository root, since `frontend` is imported from it — confirm).
aries_path = cur_dir + "/../../../../"
# Must run before the import below so that `frontend` can be found.
sys.path.append(aries_path)
from frontend import *
from IPython import get_ipython

### 🔄 Dataflow of Default GEMM

The diagram below illustrates the dataflow of the default GEMM implementation we previously explored.

<img src="../images/gemm_dataflow.png" alt="GEMM" width="800"/>

In [None]:
# GEMM: C[i0, j0] += A[i0, k0] * B[k0, j0]
# Full problem size: A is I x K, B is K x J, C is I x J.
I, J, K = 6144, 6144, 6144
# Tile size processed by one kernel invocation along each dimension.
TI, TJ, TK = 32, 32, 32
# Number of tiles along i, j and k. 6144 is a multiple of 32, so the floor
# division is exact (192 tiles per dimension).
grid = (I // TI, J // TJ, K // TK)  # grid must be a tuple

In [None]:
# Single-tile GEMM kernel: computes TileC = TileA * TileB for one
# (TI x TK) by (TK x TJ) pair of tiles. Every TileC element is reset to zero
# before its k0 loop, so the kernel overwrites TileC rather than adding to
# whatever it held before.
# NOTE(review): the decorator also names an external kernel path and passes
# the tile sizes as `para`; presumably that implementation is used for the
# hardware build while this Python body is the functional reference — confirm.
@task_kernel(external_path="aie1/adf/kernel_mm/aie_fp32_v0", para = [TI, TJ, TK])
def kernel_gemm(TileA: float32[TI, TK], 
                TileB: float32[TK, TJ], 
                TileC: float32[TI, TJ]):
    for i0 in range(0, TI):
        for j0 in range(0, TJ):
            # Clear the output element, then accumulate the dot product of
            # row i0 of TileA with column j0 of TileB.
            TileC[i0, j0] = float32(0)
            for k0 in range(0, TK):
                TileC[i0, j0] += TileA[i0, k0] * TileB[k0, j0]

# Tile-level task for one (i, j, k) position of the GEMM tile grid: loads
# the matching tiles of A and B, multiplies them with kernel_gemm, and
# writes the partial product back to the (i, j) tile of C via accstore.
# NOTE(review): the meaning of the `False` decorator argument is not visible
# in this notebook — check the task_tile documentation.
@task_tile(False)
def gemm(A: float32[I, K], B: float32[K, J], 
         C: float32[I, J], **kwargs):
    # Tile coordinates of this invocation, obtained from kwargs.
    i, j, k = aries.tile_ranks(**kwargs)

    # Tile-sized float32 buffers for the A, B and C tiles.
    L1_A = aries.buffer((TI, TK), "float32")
    L1_B = aries.buffer((TK, TJ), "float32")
    L1_C = aries.buffer((TI, TJ), "float32")
    
    # Compute tile slices for multiple dimensions
    ti = aries.arange(i*TI, (i+1)*TI)  # I tile range
    tj = aries.arange(j*TJ, (j+1)*TJ)  # J tile range
    tk = aries.arange(k*TK, (k+1)*TK)  # K tile range
    
    # Fetch the A[ti, tk] and B[tk, tj] tiles and multiply them into L1_C.
    L1_A = aries.load(A, (ti, tk))
    L1_B = aries.load(B, (tk, tj))
    kernel_gemm(L1_A, L1_B, L1_C)
    # Presumably an accumulate-store (matching the `+=` in the GEMM formula):
    # each k tile contributes a partial sum to the same C[ti, tj] — confirm.
    aries.accstore(L1_C, C, (ti, tj))

# Top-level task: instantiates gemm with the tile grid, applies it to the
# full A, B and C arrays, and returns the resulting task object.
@task_top()
def top(A: float32[I, K], B: float32[K, J], C: float32[I, J]):
    # grid is the (I//TI, J//TJ, K//TK) tuple defined with the sizes above.
    gemm_task = gemm[grid](A, B, C)
    return gemm_task

# Get the input cells that contain the decorators
# NOTE(review): the slice [2:4] assumes the notebook is run top to bottom in
# a fresh kernel, so that In[2] is the size/tile-constant cell and In[3] is
# this cell; re-running cells changes the history indices — confirm before
# relying on it.
cell_codes = get_ipython().user_ns["In"][2:4]
# Join them into one string, with a newline between each cell
all_code = "\n".join(cell_codes)

# Initialize the buffers
# Fixed seed so the random inputs are reproducible.
# NOTE(review): `np` is not imported explicitly in this notebook; presumably
# it is provided by `from frontend import *` — confirm.
np.random.seed(0)
A = np.random.rand(I, K).astype(np.float32)
B = np.random.rand(K, J).astype(np.float32)
C = np.zeros((I, J)).astype(np.float32)

# Execute on CPU
# The value returned here is what the scheduling cell passes to Schedule.
gemm_task = top(A, B, C)

### 🔄 Dataflow of GEMM after Optimizations

The diagram below illustrates the optimized GEMM dataflow after applying **L2 buffer reuse** and enhancing **parallelism** using ARIES primitives. These optimizations reduce redundant memory accesses and enable more efficient computation across tiles.

<img src="../images/gemm_dataflow_opt.png" alt="GEMM" width="800"/>

In [None]:
# Attach optimization primitives to the GEMM task to shape the hardware design
sch = Schedule(gemm_task)

############# Primitives #############
# Factor lists, one entry per tile dimension. For the L2 buffer the order is
# [i, j, k], which corresponds to: i, j, k = aries.tile_ranks()
aie_parallel_factors = [1, 1, 2]  # AIE Array Parallelism
l2_reuse_factors = [2, 2, 1]      # L2 buffer data reuse
buffer_types = [1, 1, 0]          # Buffer type of A, B, C, 1:BRAM; 0:URAM

sch.parallel(gemm_task, aie_parallel_factors)
sch.l2buffer(gemm_task, l2_reuse_factors)
sch.bufsel(gemm_task, buffer_types)
######################################

# Target board
sch.to("VCK190")

In [None]:
# Project directory (under the notebook's working directory) and the
# template directory (under the ARIES path)
prj_dir = f"{cur_dir}/project_gemm"
temp_dir = f"{aries_path}/templates"

# Generate Initial MLIR and ARIES Opts from the captured cell source,
# then compile the generated project
sch.build(all_code, prj_dir, temp_dir)
sch.compile(aries_path, prj_dir)