# Tiled (multiplexed), non-square, vectorized matrix multiplication on AIE

## Boilerplate

In [None]:
from __future__ import annotations

import os
import sys

from aie.extras.context import ExplicitlyManagedModule, RAIIMLIRContext
from aie.extras.dialects.ext import arith, func, linalg
from filelock import FileLock
import numpy as np

from aie.dialects import aie, aiex, aievec, scf, vector
from aie.dialects.aie import (
    AIEDevice,
    DMAChannelDir,
    LockAction,
    WireBundle,
)
from aie.dialects.scf import for_ as range_
import aie.extras.types as T
from aie.ir import AffineMap, AffineDimExpr
from aie.util import tiling_calculator_n_tiles
from aie.xrt import XCLBin
from util import (
    compile_with_vectorization,
    make_xclbin,
)

def yield_():
    """Emit an ``scf.yield`` with no results; terminates each ``range_`` body."""
    return scf.yield_([])


# Short aliases for stream-switch bundles, DMA directions, and lock actions.
DMA, South, North = WireBundle.DMA, WireBundle.South, WireBundle.North
S2MM, MM2S = DMAChannelDir.S2MM, DMAChannelDir.MM2S
AcquireGreaterEqual, Release = LockAction.AcquireGreaterEqual, LockAction.Release

# Problem size: the host buffers for A, B and C are all M x N.
M = N = 32

# How each matrix is cut up: A into 2 "fat rows", B into 2 "fat columns",
# C into a 2 x 2 grid of tiles.
tile_rows_A, tile_cols_A = 2, 1
tile_rows_B, tile_cols_B = 1, 2
tile_rows_C, tile_cols_C = 2, 2

# Per-tile shapes (rows, cols) derived from the splits above.
tile_m_A, tile_n_A = M // tile_rows_A, N // tile_cols_A
tile_m_B, tile_n_B = M // tile_rows_B, N // tile_cols_B
tile_m_C, tile_n_C = M // tile_rows_C, N // tile_cols_C

print(f"{tile_m_A=}, {tile_n_A=}")
print(f"{tile_m_B=}, {tile_n_B=}")
print(f"{tile_m_C=}, {tile_n_C=}")

## Context management

In [None]:
# Create the MLIR context used by every cell below. It is kept bound to a
# module-level name; presumably the context stays active for as long as `ctx`
# is alive (hence "RAII") — confirm in aie.extras.context.
ctx = RAIIMLIRContext()

## Vectorized matmul

In [None]:
# Module that holds the vectorized (aievec) kernel; finalized with `.finish()`
# once `matmul_i32_i32` has been emitted into it.
mod_aievec = ExplicitlyManagedModule()


@func.func(emit=True, sym_visibility="private")
def matmul_i32_i32(
    A: T.memref(tile_m_A, tile_n_A, T.i32()),
    B: T.memref(tile_m_B, tile_n_B, T.i32()),
    C: T.memref(tile_m_C, tile_n_C, T.i32()),
):
    """Vectorized ``C += A @ B`` for one 16x32 A tile and one 32x16 B tile.

    Each row ``j`` of ``C`` is 16 lanes wide, i.e. exactly one ``vector<16xi32>``.
    The K dimension (32) is walked in chunks of 8. Within a chunk, each
    ``A[j, k + i]`` is broadcast across a vector and multiplied by row
    ``B[k + i, :]``. The products are accumulated in 64-bit lanes, then
    shift-round-saturated back to i32 and stored to ``C[j, :]``.

    ``range_`` emits ``scf.for`` without loop-carried values (bodies end in
    ``scf.yield`` with no results). Rebinding an SSA value such as ``accum``
    inside the ``k`` body therefore does NOT carry into the next ``k``
    iteration. So the running sum goes through memory instead: every ``k``
    iteration reloads ``C[j, :]``, adds its chunk, and writes it back.
    Callers must zero ``C`` beforehand if they want ``C = A @ B``.
    """
    vec16int32 = T.vector(16, T.i32())
    vec16int64 = T.vector(16, T.i64())
    # Identity-like map for a 1-D write into the last dim of a 2-D memref.
    d1 = AffineDimExpr.get(1)
    perm_map = AffineMap.get(2, 0, [d1])

    c0 = arith.constant(0, index=True)
    for j in range_(0, 16):
        for k in range_(0, 32, 8):
            # Reload the partial row on every k step (see docstring).
            c_vec = aievec.upd(vec16int32, C, [j, c0])
            accum = aievec.ups(vec16int64, c_vec)
            a_vec = aievec.upd(vec16int32, A, [j, k])
            # Python-level unroll of the 8 lanes of this A chunk.
            for i in range(0, 8):
                broad_a = aievec.broadcast(vec16int32, a_vec, idx=i)
                b_vec = aievec.upd(vec16int32, B, [k + i, c0])
                accum = aievec.mac_elem(vec16int64, broad_a, b_vec, accum)

            shift_round_sat = aievec.srs(vec16int32, accum, arith.constant(0))
            vector.transfer_write(
                None,
                shift_round_sat,
                C,
                [j, c0],
                permutation_map=perm_map,
                in_bounds=[True],
            )
            yield_()
        yield_()


# Finalize the kernel module (rebinding the name to the finished module), then
# print its IR and the result of running the MLIR verifier on it.
mod_aievec = mod_aievec.finish()
print(mod_aievec)
print(mod_aievec.operation.verify())

In [None]:
# Module that holds the AIE device design (tiles, routing, DMAs, core);
# finalized with `mod_aie.finish()` after the device is assembled.
mod_aie = ExplicitlyManagedModule()

## (Manual) switch configuration

In [None]:
def switch_config(tile_0_0, tile_0_1, tile_0_2):
    """Hand-route the stream switches for shim (0,0), mem tile (0,1), core (0,2).

    Routes, end to end:

    * A: shim DMA 0 -> shim mux North 3 -> switchbox (0,0) North 0 ->
      mem tile DMA 0 -> North 0 -> core DMA 0.
    * B: shim DMA 1 -> shim mux North 7 -> switchbox (0,0) North 1 ->
      mem tile DMA 1 -> North 1 -> core DMA 1.
    * C: core DMA 0 -> South 0 -> mem tile DMA 2 -> South 0 ->
      switchbox (0,0) South 2 -> shim mux -> shim DMA 0.
    """

    def _wire(*routes):
        # Each route is (source bundle, source channel, dest bundle, dest channel),
        # emitted in the order given; the region is then terminated.
        for route in routes:
            aie.connect(*route)
        aie.end()

    @aie.switchbox(tile_0_0)
    def switchbox_0_0():
        _wire(
            (South, 3, North, 0),
            (South, 7, North, 1),
            (North, 0, South, 2),
        )

    @aie.shim_mux(tile_0_0)
    def shim_mux_0_0():
        _wire(
            (DMA, 0, North, 3),
            (DMA, 1, North, 7),
            (North, 2, DMA, 0),
        )

    @aie.switchbox(tile_0_1)
    def switchbox_0_1():
        _wire(
            (South, 0, DMA, 0),
            (DMA, 0, North, 0),
            (South, 1, DMA, 1),
            (DMA, 1, North, 1),
            (North, 0, DMA, 2),
            (DMA, 2, South, 0),
        )

    @aie.switchbox(tile_0_2)
    def switchbox_0_2():
        _wire(
            (South, 0, DMA, 0),
            (South, 1, DMA, 1),
            (DMA, 0, South, 0),
        )

## Data movement
 
For a tiling pattern of `[[a0],[a1]] * [b0, b1]`, i.e., `A` becomes two "fat rows", `B` becomes two "fat columns", and `C` is computed as 4 tiles (4 steps in this example). Note that this implies each of `a0` and `a1` gets broadcast (in space or time): the first product is `a0 * b0`, the second is `a0 * b1`, the third is `a1 * b0`, and the fourth is `a1 * b1`.


In [None]:
def data_movement(tile_0_1, tile_0_2):
    """Declare the buffers, locks, and DMA programs that move A, B, and C.

    The compute tile ``tile_0_2`` gets one buffer per operand tile and a pair
    of locks per buffer. The mem tile ``tile_0_1`` receives each A tile once
    and sends it on twice, because each A tile is used against both B tiles.
    It passes the B and C streams through with ``aiex.forward_bd``.

    Args:
        tile_0_1: The mem tile.
        tile_0_2: The compute tile that runs ``matmul_i32_i32``.

    Returns:
        A 9-tuple ``(buffer_0_2_a, buffer_0_2_b, buffer_0_2_c, lock_0_2_use_a,
        lock_0_2_read_in_a, lock_0_2_use_b, lock_0_2_read_in_b, lock_0_2_use_c,
        lock_0_2_write_out_c)`` for the core program.
    """
    # in: compute-tile buffers for one A tile and one B tile
    buffer_0_2_a = aie.buffer(tile_0_2, (tile_m_A, tile_n_A), T.i32())
    buffer_0_2_b = aie.buffer(tile_0_2, (tile_m_B, tile_n_B), T.i32())
    # out: compute-tile buffer for one C tile
    buffer_0_2_c = aie.buffer(tile_0_2, (tile_m_C, tile_n_C), T.i32())

    # input: mem-tile locks for its A buffer. `read_in` starts at 1 (the buffer
    # is free for the incoming S2MM) and `write_out` at 0 (nothing to send yet).
    lock_0_1_read_in_a = aie.lock(tile_0_1, init=1)
    lock_0_1_write_out_a = aie.lock(tile_0_1, init=0)

    # Compute-tile locks. For A and B, `read_in` (init 1) gates the S2MM DMA
    # and `use` (init 0) gates the core. For C the roles flip: `use` (init 1)
    # gates the core and `write_out` (init 0) gates the MM2S DMA.
    lock_0_2_read_in_a = aie.lock(tile_0_2, init=1)
    lock_0_2_use_a = aie.lock(tile_0_2, init=0)
    lock_0_2_read_in_b = aie.lock(tile_0_2, init=1)
    lock_0_2_use_b = aie.lock(tile_0_2, init=0)
    lock_0_2_use_c = aie.lock(tile_0_2, init=1)
    lock_0_2_write_out_c = aie.lock(tile_0_2, init=0)

    @aie.mem(tile_0_2)
    def mem_0_2():
        # input: S2MM 0 fills the A buffer, S2MM 1 fills the B buffer; each BD
        # acquires its `read_in` lock and releases its `use` lock.
        @aie.dma(S2MM, channel_index=0)
        def dma1():
            aiex.process_bd(lock_0_2_read_in_a, buffer_0_2_a, lock_0_2_use_a)

        @aie.dma(S2MM, channel_index=1)
        def dma2():
            aiex.process_bd(lock_0_2_read_in_b, buffer_0_2_b, lock_0_2_use_b)

        # output: MM2S 0 drains C once the core releases `write_out_c`, then
        # hands the buffer back to the core via `use_c`.
        @aie.dma(MM2S, channel_index=0)
        def dma3():
            aiex.process_bd(lock_0_2_write_out_c, buffer_0_2_c, lock_0_2_use_c)

        aie.end()

    @aie.memtile_dma(tile_0_1)
    def memtile_dma_0_1():
        # input flow: mem-tile staging buffers for A and B
        buffer_0_1_a = aie.buffer(tile_0_1, (tile_m_A, tile_n_A), T.i32())
        buffer_0_1_b = aie.buffer(tile_0_1, (tile_m_B, tile_n_B), T.i32())
        # output flow: mem-tile staging buffer for C
        buffer_0_1_c = aie.buffer(tile_0_1, (tile_m_C, tile_n_C), T.i32())

        # S2MM 0: receive an A tile, then make it available to MM2S 0.
        @aie.dma(S2MM, channel_index=0)
        def dma1():
            aiex.process_bd(lock_0_1_read_in_a, buffer_0_1_a, lock_0_1_write_out_a)

        # MM2S 0: a two-BD chain that sends the same A tile twice. The first BD
        # releases `write_out_a` again so the second BD can re-acquire it. The
        # second releases `read_in_a`, freeing the buffer for the next A tile.
        @aie.dma(MM2S, channel_index=0, num_blocks=2)
        def dma2():
            aiex.process_bd(lock_0_1_write_out_a, buffer_0_1_a, lock_0_1_write_out_a)

        @aie.another_bd(dma2)
        def dma2point5():
            aiex.process_bd(lock_0_1_write_out_a, buffer_0_1_a, lock_0_1_read_in_a)

        # B arrives on S2MM 1 and C on S2MM 2, and both pass straight through.
        # Presumably forward_bd re-sends on the MM2S channel with the same index,
        # which matches the DMA 1 -> North 1 and DMA 2 -> South 0 routes —
        # confirm in aiex.
        aiex.forward_bd(tile_0_1, buffer_0_1_b, s2mm_channel_idx=1)
        aiex.forward_bd(tile_0_1, buffer_0_1_c, s2mm_channel_idx=2)

        aie.end()

    return (
        buffer_0_2_a,
        buffer_0_2_b,
        buffer_0_2_c,
        lock_0_2_use_a,
        lock_0_2_read_in_a,
        lock_0_2_use_b,
        lock_0_2_read_in_b,
        lock_0_2_use_c,
        lock_0_2_write_out_c,
    )

## Tensor addressing and command/control processor

In [None]:
def command_control():
    """Build the NPU instruction stream that programs and starts the shim DMAs.

    Shim buffer descriptors (BDs) in column 0 are allocated sequentially:

    * BDs 0-1: the two "fat row" tiles of A (``ddr_id`` 0, MM2S channel 0).
    * BDs 2-5: the two "fat column" tiles of B, queued as b0, b1, b0, b1
      (``ddr_id`` 1, MM2S channel 1), one per product the core computes.
    * BDs 6-9: the four tiles of C in row-major order (``ddr_id`` 2, S2MM
      channel 0), each followed by an ``aiex.npu.sync``.

    Returns:
        The instruction list, starting with ``aiex.npu.get_prolog()``.
    """
    (
        _,
        _,
        (a_d1_size, a_d1_stride),
        (_a_d0_size, _a_d0_stride),
    ) = tiling_calculator_n_tiles(
        M, N, n_tile_rows=tile_rows_A, n_tile_cols=tile_cols_A
    )
    (
        _,
        _,
        (b_d1_size, b_d1_stride),
        (b_d0_size, b_d0_stride),
    ) = tiling_calculator_n_tiles(
        M, N, n_tile_rows=tile_rows_B, n_tile_cols=tile_cols_B
    )
    (
        _,
        _,
        (c_d1_size, c_d1_stride),
        (c_d0_size, c_d0_stride),
    ) = tiling_calculator_n_tiles(
        M, N, n_tile_rows=tile_rows_C, n_tile_cols=tile_cols_C
    )

    column = 0
    insts = aiex.npu.get_prolog()

    def queue_bd(direction, channel, bd, **descriptor):
        # Write one shim BD, then issue the write32 that hands it to `channel`.
        insts.extend(aiex.npu.writebd_shimtile(column, bd, **descriptor))
        insts.extend(aiex.npu.write32(direction, channel, column, bd))

    # A: tiles span whole rows, so the second one starts one tile-height of
    # rows in (the d1 dim).
    for bd, offset in enumerate((0, a_d1_size * a_d1_stride)):
        queue_bd(
            MM2S,
            0,
            bd,
            buffer_length=tile_m_A * tile_n_A,
            buffer_offset=offset,
            ddr_id=0,
        )

    # B: tiles are "tall", so the second one is offset by columns (the d0
    # dim). Each B tile is needed once per A tile, hence b0, b1, b0, b1.
    b_second = b_d0_size * b_d0_stride
    for bd, offset in enumerate((0, b_second, 0, b_second), start=2):
        queue_bd(
            MM2S,
            1,
            bd,
            buffer_length=tile_m_B * tile_n_B,
            buffer_offset=offset,
            ddr_id=1,
            d1_size=b_d1_size,
            d1_stride=b_d1_stride,
            d0_size=b_d0_size,
            d0_stride=b_d0_stride,
        )

    # C: four tiles in row-major order (c00, c01, c10, c11).
    c_col_step = c_d0_size * c_d0_stride
    c_row_step = c_d1_size * c_d1_stride
    c_offsets = (0, c_col_step, c_row_step, c_row_step + c_col_step)
    for bd, offset in enumerate(c_offsets, start=6):
        queue_bd(
            S2MM,
            0,
            bd,
            buffer_length=tile_m_C * tile_n_C,
            buffer_offset=offset,
            ddr_id=2,
            d1_size=c_d1_size,
            d1_stride=c_d1_stride,
            d0_size=c_d0_size,
            d0_stride=c_d0_stride,
        )
        insts.extend(
            aiex.npu.sync(
                channel=0,
                column=0,
                column_num=1,
                direction=0,
                row=0,
                row_num=1,
            )
        )

    return insts

## Draw the rest of the owl (assemble device module and orchestrate tiling)

Note: "tiling" may not be quite the right word here, since the four tiles of `C` are multiplexed in time on a single core, but this example extends naturally to an actual (spatial) 2 x 2 tiling.

In [None]:
@aie.device(AIEDevice.npu)
def npu():
    """Assemble the NPU device: tiles, routing, data movement, and the core.

    A single compute tile produces all ``tile_rows_C * tile_cols_C`` tiles of
    C, one after another (multiplexed in time).
    """
    # Make the kernel symbol visible in the device module. Presumably this is
    # only a declaration, with the body coming from `mod_aievec` at compile
    # time — confirm.
    matmul_i32_i32.emit(decl=True)
    # Column 0: shim (0, 0), mem tile (0, 1), compute tile (0, 2).
    tile_0_0 = aie.tile(0, 0)
    tile_0_1 = aie.tile(0, 1)
    tile_0_2 = aie.tile(0, 2)

    switch_config(tile_0_0, tile_0_1, tile_0_2)
    (
        buffer_0_2_a,
        buffer_0_2_b,
        buffer_0_2_c,
        lock_0_2_use_a,
        lock_0_2_read_in_a,
        lock_0_2_use_b,
        lock_0_2_read_in_b,
        lock_0_2_use_c,
        lock_0_2_write_out_c,
    ) = data_movement(tile_0_1, tile_0_2)

    @aie.core(tile_0_2)
    def core():
        # One iteration per C tile. The indices are unused because the DMAs
        # deliver the matching A/B tiles in the same row-major order.
        for _ in range_(0, tile_rows_C):
            for _ in range_(0, tile_cols_C):
                # Presumably hold_lock(acq, rel) acquires `acq` on entry and
                # releases `rel` on exit. That means: wait for an A tile, a B
                # tile and a free C buffer, then hand A/B back to their S2MM
                # DMAs and C to its MM2S DMA.
                with (
                    aiex.hold_lock(lock_0_2_use_a, lock_0_2_read_in_a),
                    aiex.hold_lock(lock_0_2_use_b, lock_0_2_read_in_b),
                    aiex.hold_lock(lock_0_2_use_c, lock_0_2_write_out_c),
                ):
                    # The kernel accumulates into C, so clear it first.
                    linalg.fill(0, buffer_0_2_c)
                    matmul_i32_i32(buffer_0_2_a, buffer_0_2_b, buffer_0_2_c)

                yield_()
            yield_()

In [None]:
# Finalize the device module, then print its IR and the MLIR verifier result.
# Note that `mod_aie` itself (not the return value of finish()) is what the
# compile/run cells below receive.
mod_aie.finish()
print(mod_aie.module)
print(mod_aie.module.operation.verify())

## Compile using chess

In [None]:
# `workdir` is used here and by `make_xclbin` in the run cell, but no cell in
# this notebook defines it, so running top to bottom raised NameError. Fall back
# to a directory under the current working directory unless something (e.g. an
# injected parameters cell) already set it. Create it in case the compile step
# expects it to exist.
if "workdir" not in globals():
    workdir = os.path.join(os.getcwd(), "workdir")
os.makedirs(workdir, exist_ok=True)
compile_with_vectorization(mod_aie, mod_aievec, workdir)

# Run

In [None]:
xclbin_path = make_xclbin(mod_aie, workdir)
# Serialize access to the NPU across processes that share this lock file.
with FileLock("/tmp/npu.lock"):
    xclbin = XCLBin(xclbin_path, "MLIR_AIE")
    npu_insts = command_control()
    xclbin.load_npu_instructions(npu_insts)

    # Host views of the three M x N int32 buffers, in the order A, B, C
    # (presumably matching ddr_id 0, 1, 2 in `command_control` — confirm).
    wrap_A, wrap_B, wrap_C = map(
        np.asarray, xclbin.mmap_buffers([(M, N), (M, N), (M, N)], np.int32)
    )

    A = np.random.randint(0, 10, (M, N), dtype=np.int32)
    B = np.random.randint(0, 10, (M, N), dtype=np.int32)
    C = np.zeros((M, N), dtype=np.int32)

    np.copyto(wrap_A, A, casting="no")
    np.copyto(wrap_B, B, casting="no")
    np.copyto(wrap_C, C, casting="no")

    xclbin.sync_buffers_to_device()
    xclbin.run()
    print("Running kernel")
    xclbin.wait(30)
    xclbin.sync_buffers_from_device()

    expected = A @ B
    if not np.array_equal(expected, wrap_C):
        with np.printoptions(threshold=sys.maxsize, linewidth=sys.maxsize):
            print(expected)
            print(wrap_C)
        # `assert False` is stripped under `python -O` and would let a wrong
        # result pass silently; raise explicitly instead.
        raise AssertionError("NPU result does not match A @ B")