# SPADE

This notebook reproduces the salient characteristics of the [SPADE](https://dl.acm.org/doi/10.1145/3579371.3589054) accelerator.

## Imports

Import the necessary modules.

In [None]:
# HiFiber boilerplate

from fibertree_bootstrap import *

fibertree_bootstrap(style="tree", animation='movie')

# Compilation boilerplate

import os
import sys
sys.path.insert(0, "..")

from src import utils

## Initialization

Initialize the input tensors. Tensor shapes and densities can be modified below.

**Warning:** Large tensors will overwhelm the video generation. Either:
1. Use small tensors; as a rule of thumb, the kernel should require fewer than 60 compute operations (e.g., multiplications).
2. Do not generate a video; remove the `spacetime` specification from the `mapping` before compiling.
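You can estimate the compute budget before running anything. Because `B` is dense (its density is `[1, 1]` below), SpMM performs one multiplication per nonzero of `A` for each of the `N` columns of `B`, so the count is nnz(A) × N. The helper below is a hypothetical sketch, not part of `utils`, and works on a plain nested-list matrix rather than a fibertree tensor.

```python
def spmm_compute_count(a_rows, n_cols):
    """Count the multiplications SpMM performs when B is dense.

    a_rows: A as a list of rows (M x K); zero entries are treated as empty.
    n_cols: N, the number of columns of the dense B.
    """
    nnz = sum(1 for row in a_rows for value in row if value != 0)
    return nnz * n_cols


# Example: a 3 x 4 matrix with 5 nonzeros and N = 2
a = [
    [1, 0, 2, 0],
    [0, 0, 0, 0],
    [3, 4, 0, 5],
]
print(spmm_compute_count(a, 2))  # 10
```

For the default `M = 12`, `K = 8`, and `N = 2`, a fully dense `A` would need 12 × 8 × 2 = 192 multiplications, well above the rule of thumb, so keeping `A` sparse (or shrinking the shapes) matters for video generation.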

In [None]:
K = 8
M = 12
N = 2

row_panel_shape = 2
col_panel_shape = 2
tiles_before_barrier = 2
elems_before_barrier = col_panel_shape * tiles_before_barrier
num_pes = 3
rows_to_all_pes = num_pes * row_panel_shape

density = [0.9, 0.5]
seed = 1

A_MK = Tensor.fromRandom(rank_ids=["M", "K"], shape=[M, K], seed=seed, density=density, name="A")
B_KN = Tensor.fromRandom(rank_ids=["K", "N"], shape=[K, N], seed=seed + 1, density=[1, 1], name="B")
C_MN = Tensor.fromRandom(rank_ids=["M", "N"], shape=[M, N], seed=seed + 2, density=[1, 1], name="C")

## Compile and Run

Below are the TeAAL specifications for SPADE. To simulate the accelerator:
1. Compile a specification to HiFiber by running its cell; this inserts a new cell containing the generated code.
2. Run the new cell, which will
    - Execute the kernel, multiplying the matrices defined above
    - Generate visualizations of the kernel's actions

#### Notes

- Video generation requires small tensors. If you are using large tensors, remove the `spacetime` specification to generate a kernel that does not produce videos; the outputs can still be checked below.
- The partition shapes above are reduced for visualization purposes. The real SPADE uses problem-specific values for `row_panel_shape`, `col_panel_shape`, and `tiles_before_barrier`, and its `num_pes` depends on the hardware configuration.
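To see how these parameters shape the partitions, the sketch below mimics two-level uniform-shape partitioning in plain Python: each level cuts a coordinate range into consecutive chunks of a fixed size. It is an illustrative reimplementation, not the fibertree code, and the helper `uniform_chunks` is hypothetical. It also assumes, following the `M1.coord` space mapping, that PE `i` handles the `i`-th `M1` panel of each `M2` block.

```python
# Example parameters from the initialization cell
K, M = 8, 12
row_panel_shape = 2
col_panel_shape = 2
tiles_before_barrier = 2
num_pes = 3
elems_before_barrier = col_panel_shape * tiles_before_barrier  # 4 K coordinates per K2 partition
rows_to_all_pes = num_pes * row_panel_shape                    # 6 rows per M2 partition


def uniform_chunks(start, stop, shape):
    """Split [start, stop) into consecutive (lo, hi) chunks of at most `shape` coordinates."""
    return [(lo, min(lo + shape, stop)) for lo in range(start, stop, shape)]


# K: K2 partitions (elems_before_barrier wide), each split into K1 tiles of col_panel_shape
k_partitions = {
    (lo, hi): uniform_chunks(lo, hi, col_panel_shape)
    for lo, hi in uniform_chunks(0, K, elems_before_barrier)
}

# M: M2 partitions (rows_to_all_pes tall), each split into M1 panels of row_panel_shape
m_partitions = {
    (lo, hi): uniform_chunks(lo, hi, row_panel_shape)
    for lo, hi in uniform_chunks(0, M, rows_to_all_pes)
}

print(k_partitions)
for (lo, hi), panels in m_partitions.items():
    for pe, (p_lo, p_hi) in enumerate(panels):
        print(f"M2 rows [{lo}, {hi}): PE {pe} handles rows [{p_lo}, {p_hi})")
```

With these values, `K` splits into two `K2` partitions of two tiles each, and `M` splits into two `M2` blocks of three panels, one per PE.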

### SpMM

The following is the TeAAL specification for the SpMM kernel of SPADE:

```yaml
einsum:
    declaration:
        A: [K, M]
        B: [K, N]
        Z: [M, N]
    expressions:
        - Z[m, n] = A[k, m] * B[k, n]
mapping:
    rank-order:
        A: [M, K]
        B: [K, N]
        Z: [M, N]
    partitioning:
        Z:
            K: [uniform_shape(elems_before_barrier), uniform_shape(col_panel_shape)]
            M: [uniform_shape(rows_to_all_pes), uniform_shape(row_panel_shape)]
            (K1, M0, K0): [flatten()]
    loop-order:
        Z: [M2, K2, M1, K1M0K0, N]
    spacetime:
        Z:
            space: [M1.coord]
            time: [M2, K2, K1M0K0, N]
```

However, the current TeAAL compiler cannot flatten partitioned ranks above the bottom rank (in this case, `K1`). We therefore present a TeAAL specification that compiles but flattens only `M0` and `K0`, followed by a modified version of the resulting HiFiber that flattens all three ranks.
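To illustrate what flattening `(K1, M0, K0)` means, the sketch below models a fiber tree as nested Python dicts (an assumption made for illustration; this is not fibertree's `flattenRanks`) and collapses the `K1 -> M0 -> K0` levels under one `(M2, M1, K2)` prefix into a single rank keyed by `(k1, m0, k0)` tuples. The hypothetical `flatten` helper follows the `levels` convention of the `flattenRanks(depth=3, levels=2, ...)` call in the modified HiFiber, where `levels=2` combines three ranks.

```python
def flatten(fiber, levels):
    """Flatten a fiber and the `levels` ranks below it into one rank.

    `fiber` is a nested dict (coordinate -> payload); the result maps
    coordinate tuples to leaf values, in traversal order. Empty
    sub-fibers contribute nothing.
    """
    flat = {}
    for coord, payload in fiber.items():
        if levels == 0:
            flat[(coord,)] = payload
        else:
            for sub_coords, value in flatten(payload, levels - 1).items():
                flat[(coord,) + sub_coords] = value
    return flat


# K1 -> M0 -> K0 -> value for one (M2, M1, K2) prefix, using absolute
# coordinates: K1 tiles start at 0 and 2, and this M1 panel holds rows 0-1.
a_k1 = {
    0: {0: {0: 1.0, 1: 2.0}, 1: {1: 3.0}},
    2: {1: {2: 4.0, 3: 5.0}},
}

a_k1m0k0 = flatten(a_k1, levels=2)
for (k1, m0, k0), a_val in a_k1m0k0.items():
    print((k1, m0, k0), a_val)
```

A single loop over the flattened rank visits the same nonzeros, in the same order, as three nested loops over `K1`, `M0`, and `K0`. The tuple order follows the rank order at flattening time, which is why the HiFiber first swizzles `A` to `[M2, M1, K2, K1, M0, K0]`: the flattened rank then walks every row of the panel within one `K1` tile before moving to the next tile.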

In [None]:
yaml = """
einsum:
    declaration:
        A: [K, M]
        B: [K, N]
        Z: [M, N]
    expressions:
        - Z[m, n] = A[k, m] * B[k, n]
mapping:
    rank-order:
        A: [M, K]
        B: [K, N]
        Z: [M, N]
    partitioning:
        Z:
            K: [uniform_shape(elems_before_barrier), uniform_shape(col_panel_shape)]
            M: [uniform_shape(rows_to_all_pes), uniform_shape(row_panel_shape)]
            (M0, K0): [flatten()]
    loop-order:
        Z: [M2, K2, M1, K1, M0K0, N]
    spacetime:
        Z:
            space: [M1.coord]
            time: [M2, K2, K1, M0K0, N]
"""

utils.compile(yaml)

## Check Results

Check that the generated code computes the correct result.
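`utils.check_matmul` comes from this repository's `src/utils` module, whose implementation is not shown in this notebook. As a hedged sketch of what such a check typically involves, the code below computes a dense reference product `Z[m, n] = sum_k A[m, k] * B[k, n]` (with `A` stored `M x K`, as `A_MK` is) and compares it element-wise with a tolerance. The helpers `dense_matmul` and `allclose_2d` are hypothetical and operate on nested Python lists rather than fibertree tensors.

```python
import math


def dense_matmul(a, b):
    """Reference product Z[m][n] = sum_k A[m][k] * B[k][n] for nested lists."""
    k_dim, n_dim = len(b), len(b[0])
    return [[sum(row[k] * b[k][n] for k in range(k_dim)) for n in range(n_dim)] for row in a]


def allclose_2d(z, expected, rel_tol=1e-9, abs_tol=1e-12):
    """True if both matrices have the same shape and all entries are close."""
    if len(z) != len(expected):
        return False
    for z_row, e_row in zip(z, expected):
        if len(z_row) != len(e_row):
            return False
        if not all(math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(z_row, e_row)):
            return False
    return True


a = [[1, 0], [0, 2], [3, 4]]  # M x K, sparse
b = [[1, 2], [3, 4]]          # K x N, dense
z = dense_matmul(a, b)
print(z)                                            # [[1, 2], [6, 8], [15, 22]]
print(allclose_2d(z, [[1, 2], [6, 8], [15, 22]]))   # True
```

A tolerance-based comparison is a common safeguard for floating-point results; for integer inputs an exact comparison would also work.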

**Note**: Run the check below only after compiling the specification and running the generated kernel.

In [None]:
utils.check_matmul(A_MK, B_KN, Z_MN)

### Modified HiFiber for SpMM
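Before looking at the fibertree code, it may help to see the same traversal in plain Python. The sketch below is a hypothetical stand-in, not generated HiFiber: it stores the partitioned `A` as nested dicts keyed by absolute coordinates (assuming, as with `uniform_shape`, that each partition is named by its first coordinate), visits ranks in the order `[M2, K2, M1, K1M0K0, N]`, and omits the canvas. Because coordinates are absolute, `k0` alone indexes `B` here, whereas the HiFiber looks up `b_k1.getPayload(k1, k0)` in the partitioned `B`. The names `partition_a` and `spade_spmm` are illustrative only.

```python
from collections import defaultdict


def partition_a(a, rows_to_all_pes, row_panel_shape, elems_before_barrier, col_panel_shape):
    """Nest the nonzeros of A (M x K lists) as {m2: {k2: {m1: [((k1, m0, k0), value)]}}}.

    Each partition is named by its first (absolute) coordinate.
    """
    tree = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for m, row in enumerate(a):
        for k, value in enumerate(row):
            if value == 0:
                continue  # sparse: only nonzeros are stored
            m2 = m - m % rows_to_all_pes
            m1 = m - m % row_panel_shape
            k2 = k - k % elems_before_barrier
            k1 = k - k % col_panel_shape
            tree[m2][k2][m1].append(((k1, m, k), value))
    return tree


def spade_spmm(a, b, rows_to_all_pes, row_panel_shape, elems_before_barrier, col_panel_shape):
    """Compute Z = A @ B, visiting ranks in the order [M2, K2, M1, K1M0K0, N]."""
    n_dim = len(b[0])
    z = [[0] * n_dim for _ in a]
    a_tree = partition_a(a, rows_to_all_pes, row_panel_shape, elems_before_barrier, col_panel_shape)
    for m2 in sorted(a_tree):
        for k2 in sorted(a_tree[m2]):
            for m1 in sorted(a_tree[m2][k2]):  # each M1 panel maps to one PE
                # Flattened K1M0K0 rank: a single loop over (k1, m0, k0) tuples
                for (k1, m0, k0), a_val in sorted(a_tree[m2][k2][m1]):
                    for n in range(n_dim):
                        # k0 is absolute, so it indexes B directly
                        z[m0][n] += a_val * b[k0][n]
    return z


M, K, N = 12, 8, 2
A = [[(m + k) % 4 + 1 if (m + k) % 3 else 0 for k in range(K)] for m in range(M)]
B = [[k + n + 1 for n in range(N)] for k in range(K)]

Z = spade_spmm(A, B, rows_to_all_pes=6, row_panel_shape=2, elems_before_barrier=4, col_panel_shape=2)
reference = [[sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)] for m in range(M)]
print(Z == reference)  # True
```

Reordering and flattening only change the order in which partial products are visited, so the result matches an ordinary matrix product.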

The following loop nest is a modified version of the HiFiber produced from the TeAAL specification above; it flattens `(K1, M0, K0)` into a single rank. Comments in the cell mark each change.

In [None]:
Z_M2M1M0N = Tensor(rank_ids=["M2", "M1", "M0", "N"], name="Z")
tmp0 = A_MK
tmp1 = tmp0.splitUniform(elems_before_barrier, depth=1)
tmp2 = tmp1.splitUniform(col_panel_shape, depth=2)
A_MK2K1K0 = tmp2
A_MK2K1K0.setRankIds(rank_ids=["M", "K2", "K1", "K0"])
tmp3 = A_MK2K1K0
tmp4 = tmp3.splitUniform(rows_to_all_pes, depth=0)
tmp5 = tmp4.splitUniform(row_panel_shape, depth=1)
A_M2M1M0K2K1K0 = tmp5
A_M2M1M0K2K1K0.setRankIds(rank_ids=["M2", "M1", "M0", "K2", "K1", "K0"])
tmp6 = B_KN
tmp7 = tmp6.splitUniform(elems_before_barrier, depth=0)
tmp8 = tmp7.splitUniform(col_panel_shape, depth=1)
B_K2K1K0N = tmp8
B_K2K1K0N.setRankIds(rank_ids=["K2", "K1", "K0", "N"])
z_m2 = Z_M2M1M0N.getRoot()
A_M2M1K2K1M0K0 = A_M2M1M0K2K1K0.swizzleRanks(rank_ids=["M2", "M1", "K2", "K1", "M0", "K0"])
tmp9 = A_M2M1K2K1M0K0
# Flatten (K1, M0, K0) together
tmp10 = tmp9.flattenRanks(depth=3, levels=2, coord_style="tuple")
A_M2M1K2K1M0K0_flat = tmp10
# Update rank names
A_M2M1K2K1M0K0_flat.setRankIds(rank_ids=["M2", "M1", "K2", "K1M0K0"])
b_k2 = B_K2K1K0N.getRoot()
# Use updated rank names
A_M2K2M1K1M0K0 = A_M2M1K2K1M0K0_flat.swizzleRanks(rank_ids=["M2", "K2", "M1", "K1M0K0"])
a_m2 = A_M2K2M1K1M0K0.getRoot()
canvas = createCanvas(A_M2K2M1K1M0K0, B_K2K1K0N, Z_M2M1M0N)
for m2_pos, (m2, (z_m1, a_k2)) in enumerate(z_m2 << a_m2):
    for k2_pos, (k2, (a_m1, b_k1)) in enumerate(a_k2 & b_k2):
        # Update the name of the A-fiber
        for m1, (z_m0, a_k1m0k0) in z_m1 << a_m1:
            # Remove the K1 loop and add the new flattened K1M0K0 loop
            for k1m0k0_pos, ((k1, m0, k0), a_val) in enumerate(a_k1m0k0):
                z_n = z_m0.getPayloadRef(m0)
                # Update the access to B
                b_n = b_k1.getPayload(k1, k0)
                for n_pos, (n, (z_ref, b_val)) in enumerate(z_n << b_n):
                    z_ref += a_val * b_val
                    # Update the spacetime stamp
                    canvas.addActivity((m2, k2, m1, (k1, m0, k0)), (k2, k1, k0, n), (m2, m1, m0, n), spacetime=((m1 - m2,), (m2_pos, k2_pos, k1m0k0_pos, n_pos)))
tmp11 = Z_M2M1M0N
tmp12 = tmp11.mergeRanks(depth=0, levels=2, coord_style="absolute")
tmp12.setRankIds(rank_ids=["M", "N"])
Z_MN = tmp12
displayCanvas(canvas)

## Check Results

Check that the modified loop nest computes the correct result.

**Note**: Run this cell only after running the kernel (above cell).

In [None]:
utils.check_matmul(A_MK, B_KN, Z_MN)

### SDDMM

The following is the TeAAL specification for the SDDMM kernel of SPADE:

```yaml
einsum:
    declaration:
        A: [K, M]
        B: [K, N]
        C: [M, N]
        Z: [M, N]
    expressions:
        - Z[m, n] = A[k, m] * B[k, n] * C[m, n]
mapping:
    rank-order:
        A: [M, K]
        B: [K, N]
        C: [M, N]
        Z: [M, N]
    partitioning:
        Z:
            K: [uniform_shape(elems_before_barrier), uniform_shape(col_panel_shape)]
            M: [uniform_shape(rows_to_all_pes), uniform_shape(row_panel_shape)]
            (K1, M0, K0): [flatten()]
    loop-order:
        Z: [M2, K2, M1, K1M0K0, N]
    spacetime:
        Z:
            space: [M1.coord]
            time: [M2, K2, K1M0K0, N]
```

However, the current TeAAL compiler cannot flatten partitioned ranks above the bottom rank (in this case, `K1`). We therefore present a TeAAL specification that compiles but flattens only `M0` and `K0`, followed by a modified version of the resulting HiFiber that flattens all three ranks.

In [None]:
yaml = """
einsum:
    declaration:
        A: [K, M]
        B: [K, N]
        C: [M, N]
        Z: [M, N]
    expressions:
        - Z[m, n] = A[k, m] * B[k, n] * C[m, n]
mapping:
    rank-order:
        A: [M, K]
        B: [K, N]
        C: [M, N]
        Z: [M, N]
    partitioning:
        Z:
            K: [uniform_shape(elems_before_barrier), uniform_shape(col_panel_shape)]
            M: [uniform_shape(rows_to_all_pes), uniform_shape(row_panel_shape)]
            (M0, K0): [flatten()]
    loop-order:
        Z: [M2, K2, M1, K1, M0K0, N]
    spacetime:
        Z:
            space: [M1.coord]
            time: [M2, K2, K1, M0K0, N]
"""

utils.compile(yaml)

## Check Results

Check that the generated code computes the correct result.
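Because `C[m, n]` does not depend on `k`, the SDDMM expression factors as `Z[m, n] = C[m, n] * sum_k A[k, m] * B[k, n]`: the result is the element-wise product of `C` with the SpMM result. The hedged sketch below uses that identity to build a dense reference on nested lists (with `A` stored `M x K`, as `A_MK` is); `dense_sddmm` is a hypothetical helper and not necessarily how `utils.check_sddmm` works.

```python
def dense_sddmm(a, b, c):
    """Reference Z[m][n] = C[m][n] * sum_k A[m][k] * B[k][n] (A stored M x K)."""
    k_dim = len(b)
    return [
        [c_val * sum(a_row[k] * b[k][n] for k in range(k_dim)) for n, c_val in enumerate(c_row)]
        for a_row, c_row in zip(a, c)
    ]


a = [[1, 0], [0, 2]]
b = [[1, 2], [3, 4]]
c = [[1, 0], [2, 3]]

# The loop nest multiplies by C inside the K reduction; factoring C out gives the same sum.
unfactored = [
    [sum(a[m][k] * b[k][n] * c[m][n] for k in range(len(b))) for n in range(len(b[0]))]
    for m in range(len(a))
]
print(dense_sddmm(a, b, c))                # [[1, 0], [12, 24]]
print(unfactored == dense_sddmm(a, b, c))  # True
```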

**Note**: Run the check below only after compiling the specification and running the generated kernel.

In [None]:
utils.check_sddmm(A_MK, B_KN, C_MN, Z_MN)

### Modified HiFiber for SDDMM
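Both modified loop nests stamp each multiply with `spacetime=((m1 - m2,), (m2_pos, k2_pos, k1m0k0_pos, n_pos))`. Since `m1` and `m2` are absolute coordinates, `m1 - m2` is the row offset of the `M1` panel within its `M2` block, so each PE gets a distinct space coordinate (a multiple of `row_panel_shape`), while time is built from loop positions, which count only the nonempty fibers visited. The plain-Python sketch below reproduces that stamping for a small, hypothetical nested-dict stand-in for `A`, assuming a dense `B` so that every `K2` of `A` survives the intersection and every `n` is visited.

```python
def spacetime_stamps(a_tree, n_dim):
    """Yield the (space, time) stamps of the modified loop nests.

    a_tree: {m2: {k2: {m1: [(k1, m0, k0), ...]}}} holding the nonzero
    coordinates of A with absolute coordinates (hypothetical stand-in).
    B is assumed dense.
    """
    for m2_pos, m2 in enumerate(sorted(a_tree)):
        for k2_pos, k2 in enumerate(sorted(a_tree[m2])):
            for m1 in sorted(a_tree[m2][k2]):
                for k1m0k0_pos, _ in enumerate(sorted(a_tree[m2][k2][m1])):
                    for n_pos in range(n_dim):
                        # space: row offset of the M1 panel inside its M2 block
                        yield (m1 - m2,), (m2_pos, k2_pos, k1m0k0_pos, n_pos)


# rows_to_all_pes = 6 and row_panel_shape = 2, as in the initialization cell
a_tree = {
    0: {0: {0: [(0, 0, 1)], 4: [(0, 4, 0), (2, 5, 3)]}},
    6: {4: {8: [(4, 9, 5)]}},
}
for space, time in spacetime_stamps(a_tree, n_dim=1):
    print(space, time)
```

The first two stamps share a time but differ in space, so two PEs work in parallel. In the second `M2` block, `k2_pos` is 0 even though the `K2` coordinate is 4, because time is measured in positions rather than coordinates.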

The following loop nest is a modified version of the HiFiber produced from the TeAAL specification above; it flattens `(K1, M0, K0)` into a single rank. Comments in the cell mark each change.

In [None]:
Z_M2M1M0N = Tensor(rank_ids=["M2", "M1", "M0", "N"], name="Z")
tmp0 = A_MK
tmp1 = tmp0.splitUniform(rows_to_all_pes, depth=0)
tmp2 = tmp1.splitUniform(row_panel_shape, depth=1)
A_M2M1M0K = tmp2
A_M2M1M0K.setRankIds(rank_ids=["M2", "M1", "M0", "K"])
tmp3 = A_M2M1M0K
tmp4 = tmp3.splitUniform(elems_before_barrier, depth=3)
tmp5 = tmp4.splitUniform(col_panel_shape, depth=4)
A_M2M1M0K2K1K0 = tmp5
A_M2M1M0K2K1K0.setRankIds(rank_ids=["M2", "M1", "M0", "K2", "K1", "K0"])
tmp6 = B_KN
tmp7 = tmp6.splitUniform(elems_before_barrier, depth=0)
tmp8 = tmp7.splitUniform(col_panel_shape, depth=1)
B_K2K1K0N = tmp8
B_K2K1K0N.setRankIds(rank_ids=["K2", "K1", "K0", "N"])
tmp9 = C_MN
tmp10 = tmp9.splitUniform(rows_to_all_pes, depth=0)
tmp11 = tmp10.splitUniform(row_panel_shape, depth=1)
C_M2M1M0N = tmp11
C_M2M1M0N.setRankIds(rank_ids=["M2", "M1", "M0", "N"])
z_m2 = Z_M2M1M0N.getRoot()
A_M2M1K2K1M0K0 = A_M2M1M0K2K1K0.swizzleRanks(rank_ids=["M2", "M1", "K2", "K1", "M0", "K0"])
tmp12 = A_M2M1K2K1M0K0
# Flatten (K1, M0, K0) together
tmp13 = tmp12.flattenRanks(depth=3, levels=2, coord_style="tuple")
A_M2M1K2K1M0K0_flat = tmp13
# Update rank names
A_M2M1K2K1M0K0_flat.setRankIds(rank_ids=["M2", "M1", "K2", "K1M0K0"])
b_k2 = B_K2K1K0N.getRoot()
c_m2 = C_M2M1M0N.getRoot()
# Use updated rank names
A_M2K2M1K1M0K0 = A_M2M1K2K1M0K0_flat.swizzleRanks(rank_ids=["M2", "K2", "M1", "K1M0K0"])
a_m2 = A_M2K2M1K1M0K0.getRoot()
canvas = createCanvas(A_M2K2M1K1M0K0, B_K2K1K0N, C_M2M1M0N, Z_M2M1M0N)
for m2_pos, (m2, (z_m1, (a_k2, c_m1))) in enumerate(z_m2 << (a_m2 & c_m2)):
    for k2_pos, (k2, (a_m1, b_k1)) in enumerate(a_k2 & b_k2):
        # Update the name of the A-fiber
        for m1, (z_m0, (a_k1m0k0, c_m0)) in z_m1 << (a_m1 & c_m1):
            # Remove the K1 loop and add the new flattened K1M0K0 loop
            for k1m0k0_pos, ((k1, m0, k0), a_val) in enumerate(a_k1m0k0):
                z_n = z_m0.getPayloadRef(m0)
                # Update the access to B
                b_n = b_k1.getPayload(k1, k0)
                c_n = c_m0.getPayload(m0)
                for n_pos, (n, (z_ref, (b_val, c_val))) in enumerate(z_n << (b_n & c_n)):
                    z_ref += a_val * b_val * c_val
                    # Update the spacetime stamp
                    canvas.addActivity((m2, k2, m1, (k1, m0, k0)), (k2, k1, k0, n), (m2, m1, m0, n), (m2, m1, m0, n), spacetime=((m1 - m2,), (m2_pos, k2_pos, k1m0k0_pos, n_pos)))
tmp14 = Z_M2M1M0N
tmp15 = tmp14.mergeRanks(depth=0, levels=2, coord_style="absolute")
tmp15.setRankIds(rank_ids=["M", "N"])
Z_MN = tmp15
displayCanvas(canvas)

## Check Results

Check that the modified loop nest computes the correct result.

**Note**: Run this cell only after running the kernel (above cell).

In [None]:
utils.check_sddmm(A_MK, B_KN, C_MN, Z_MN)