# DSTC

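This notebook reproduces the salient characteristics of [DSTC](https://dl.acm.org/doi/10.1109/ISCA52012.2021.00088) (the Dual-side Sparse Tensor Core, ISCA 2021) by modeling its sparse matrix-matrix multiplication (SpGEMM) dataflow.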

## Imports

Import the necessary modules.

In [None]:
# HiFiber boilerplate

from fibertree_bootstrap import *

fibertree_bootstrap(style="tree", animation='movie')

# Compilation boilerplate

import os
import sys
sys.path.insert(0, "..")

from src import utils

## Initialization

Initialize the input tensors. Tensor shapes and densities can be modified below.

**Warning:** Large tensors will overwhelm video generation. Either:
1. Use small tensors; as a rule of thumb, the kernel should perform fewer than 60 compute operations (e.g., multiplications).
2. Do not generate a video; remove (or leave commented out) the `spacetime` specification in the `mapping` before compiling.
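To check a configuration against the rule of thumb above, note that an outer-product kernel multiplies every nonzero of `A[k, :]` with every nonzero of `B[k, :]`. The sketch below counts those multiplications for small dense Python lists. The lists and the function name `count_multiplications` are hypothetical stand-ins, not part of the notebook or of fibertree.

```python
def count_multiplications(a_km, b_kn):
    """Count the multiplications an outer-product SpGEMM performs for
    Z[m, n] = sum_k A[k, m] * B[k, n], skipping zero operands.

    a_km is a K x M list of lists and b_kn a K x N list of lists
    (a hypothetical dense stand-in for the fibertree tensors).
    """
    total = 0
    for a_row, b_row in zip(a_km, b_kn):
        nnz_a = sum(1 for v in a_row if v != 0)  # nonzeros in A[k, :]
        nnz_b = sum(1 for v in b_row if v != 0)  # nonzeros in B[k, :]
        total += nnz_a * nnz_b                   # every pair is multiplied once
    return total


A = [[1, 0, 2],
     [0, 0, 3]]  # K = 2, M = 3
B = [[4, 5],
     [0, 6]]     # K = 2, N = 2
print(count_multiplications(A, B))  # 2*2 + 1*1 = 5
```

If the count for your shapes and densities is well above 60, prefer option 2.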

In [None]:
K = 4
M = 12
N = 8

M2 = 4
N2 = 4
M1 = 4
N1 = 4
M0 = 2
N0 = 2

density = [0.9, 0.5]
seed = 0

A_KM = Tensor.fromRandom(rank_ids=["K", "M"], shape=[K, M], seed=seed, density=density, name="A")
B_KN = Tensor.fromRandom(rank_ids=["K", "N"], shape=[K, N], seed=seed + 1, density=density, name="B")

## Compile and Run

Below is the TeAAL specification for DSTC_SpGEMM. To simulate the accelerator:
1. Compile it to HiFiber by running the cell below; this inserts a new cell containing the generated HiFiber code.
2. Run the new cell, which will:
    - Execute the kernel, multiplying the matrices defined above
    - Generate visualizations of the kernel's actions (only if the `spacetime` specification is enabled)
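Ignoring the partitioning, the kernel computes an outer-product SpGEMM: for each `k` present in both inputs, every nonzero `A[k, m]` is multiplied with every nonzero `B[k, n]` and accumulated into `Z[m, n]`. The sketch below shows that core loop nest on nested Python dicts (`{k: {coord: value}}`), a hypothetical format chosen for readability rather than the fibertree representation.

```python
def outer_product_spgemm(a_km, b_kn):
    """Compute Z[m, n] = sum_k A[k, m] * B[k, n] from sparse inputs stored as
    {k: {coord: value}}; returns Z as {m: {n: value}}."""
    z_mn = {}
    # Only k coordinates nonzero in both A and B contribute; this mirrors the
    # intersection `a_k & b_k` in the generated HiFiber.
    for k in sorted(a_km.keys() & b_kn.keys()):
        for m, a_val in a_km[k].items():
            for n, b_val in b_kn[k].items():
                z_row = z_mn.setdefault(m, {})
                z_row[n] = z_row.get(n, 0) + a_val * b_val
    return z_mn


A = {0: {0: 1, 2: 2}, 1: {1: 3}}             # A[k][m]
B = {0: {0: 4}, 1: {0: 5, 3: 6}, 2: {1: 7}}  # B[k][n]; k = 2 has no match in A
print(outer_product_spgemm(A, B))
# {0: {0: 4}, 2: {0: 8}, 1: {0: 15, 3: 18}}
```

The generated code additionally tiles `M` and `N` (the `M3`/`N3` loops) outside the `K` loop; tiling changes the order in which products are accumulated, not which products are computed.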

#### Notes

- Small tensors are required for video generation. If you are using large tensors, remove the `spacetime` specification to generate a kernel that does not produce videos; the outputs can still be checked below.
- The partition sizes above are reduced for visualization purposes. The real DSTC uses `M0 = 4`, `M1 = 8`, `N0 = 4`, and `N1 = 16`.

In [None]:
yaml = """
einsum:
  declaration:
    A: [K, M]
    B: [K, N]
    Z: [M, N]
  expressions:
    - Z[m, n] = A[k, m] * B[k, n]
mapping:
  rank-order:
    A: [K, M]
    B: [K, N]
    Z: [M, N]
  partitioning:
    Z:
      M: [uniform_shape(M2), uniform_occupancy(A.M1), uniform_occupancy(A.M0)] 
      N: [uniform_shape(N2), uniform_occupancy(B.N1), uniform_occupancy(B.N0)]
  loop-order:
    Z: [M3, N3, K, M2, N2, M1, N1, M0, N0]
  # spacetime:
  #   Z:
  #     space: [M2, N2, M1, N1, N0]
  #     time: [M3, N3, K, M0]
"""

utils.compile(yaml)
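The `partitioning` section splits `M` (and likewise `N`) three times: `uniform_shape(M2)` cuts the coordinate range into tiles of `M2` coordinates, while `uniform_occupancy(A.M1)` and `uniform_occupancy(A.M0)` cut a tile's nonzeros into groups of a fixed count (`M1`, then `M0`), as the `splitUniform` and `splitEqual` calls in the generated code suggest. The sketch below imitates both styles on a sorted list of nonzero coordinates. The helper names are hypothetical, and it assumes each partition is labeled by its first coordinate (the range start for shape splits, the first nonzero for occupancy splits).

```python
def split_uniform_shape(coords, size):
    """Split sorted nonzero coordinates into fixed-size coordinate ranges
    [0, size), [size, 2*size), ...; each partition is keyed by its range
    start, and ranges with no nonzeros are omitted."""
    parts = {}
    for c in coords:
        start = (c // size) * size
        parts.setdefault(start, []).append(c)
    return parts


def split_uniform_occupancy(coords, occupancy):
    """Split sorted nonzero coordinates into chunks of `occupancy` nonzeros;
    each chunk is keyed by its first coordinate, and the last chunk may be
    smaller."""
    return {coords[i]: coords[i:i + occupancy]
            for i in range(0, len(coords), occupancy)}


row_nz = [0, 1, 3, 4, 5, 6, 7, 9, 10]           # nonzero m coordinates of one A[k, :], M = 12
tiles = split_uniform_shape(row_nz, 4)           # M2 = 4
print(tiles)   # {0: [0, 1, 3], 4: [4, 5, 6, 7], 8: [9, 10]}
groups = split_uniform_occupancy(tiles[0], 4)    # M1 = 4
print(groups)  # {0: [0, 1, 3]}
chunks = split_uniform_occupancy(groups[0], 2)   # M0 = 2
print(chunks)  # {0: [0, 1], 3: [3]}
```

Because occupancy-based boundaries depend on where the nonzeros fall, they can differ from one `k` to the next, unlike the fixed `M2` tiles.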

In [None]:
# Autogenerated HiFiber

Z_M3N3M2N2M1N1M0N0 = Tensor(rank_ids=["M3", "N3", "M2", "N2", "M1", "N1", "M0", "N0"], name="Z")
tmp0 = A_KM
tmp1 = tmp0.splitUniform(M2, depth=1)
A_KM3M2I = tmp1
A_KM3M2I.setRankIds(rank_ids=["K", "M3", "M2I"])
tmp2 = B_KN
tmp3 = tmp2.splitUniform(N2, depth=1)
B_KN3N2I = tmp3
B_KN3N2I.setRankIds(rank_ids=["K", "N3", "N2I"])
z_m3 = Z_M3N3M2N2M1N1M0N0.getRoot()
A_M3KM2I = A_KM3M2I.swizzleRanks(rank_ids=["M3", "K", "M2I"])
B_N3KN2I = B_KN3N2I.swizzleRanks(rank_ids=["N3", "K", "N2I"])
a_m3 = A_M3KM2I.getRoot()
b_n3 = B_N3KN2I.getRoot()
for m3, (z_n3, a_k) in z_m3 << a_m3:
    for n3, (z_m2, b_k) in z_n3 << b_n3:
        for k, (a_m2i, b_n2i) in a_k & b_k:
            A_M2I = Tensor.fromFiber(rank_ids=["M2I"], fiber=a_m2i, name="A")
            B_N2I = Tensor.fromFiber(rank_ids=["N2I"], fiber=b_n2i, name="B")
            tmp4 = A_M2I
            tmp5 = tmp4.splitEqual(M1)
            A_M2M1I = tmp5
            A_M2M1I.setRankIds(rank_ids=["M2", "M1I"])
            tmp6 = B_N2I
            tmp7 = tmp6.splitEqual(N1)
            B_N2N1I = tmp7
            B_N2N1I.setRankIds(rank_ids=["N2", "N1I"])
            a_m2 = A_M2M1I.getRoot()
            b_n2 = B_N2N1I.getRoot()
            for m2, (z_n2, a_m1i) in z_m2 << a_m2:
                A_M1I = Tensor.fromFiber(rank_ids=["M1I"], fiber=a_m1i, name="A")
                tmp8 = A_M1I
                tmp9 = tmp8.splitEqual(M0)
                A_M1M0 = tmp9
                A_M1M0.setRankIds(rank_ids=["M1", "M0"])
                a_m1 = A_M1M0.getRoot()
                for n2, (z_m1, b_n1i) in z_n2 << b_n2:
                    B_N1I = Tensor.fromFiber(rank_ids=["N1I"], fiber=b_n1i, name="B")
                    tmp10 = B_N1I
                    tmp11 = tmp10.splitEqual(N0)
                    B_N1N0 = tmp11
                    B_N1N0.setRankIds(rank_ids=["N1", "N0"])
                    b_n1 = B_N1N0.getRoot()
                    for m1, (z_n1, a_m0) in z_m1 << a_m1:
                        for n1, (z_m0, b_n0) in z_n1 << b_n1:
                            for m0, (z_n0, a_val) in z_m0 << a_m0:
                                for n0, (z_ref, b_val) in z_n0 << b_n0:
                                    z_ref += a_val * b_val
tmp12 = Z_M3N3M2N2M1N1M0N0
tmp13 = tmp12.swizzleRanks(rank_ids=["M3", "M2", "M1", "M0", "N3", "N2", "N1", "N0"])
tmp14 = tmp13.mergeRanks(depth=4, levels=3, coord_style="absolute")
tmp15 = tmp14.mergeRanks(depth=0, levels=3, coord_style="absolute")
tmp15.setRankIds(rank_ids=["M", "N"])
Z_MN = tmp15
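The last lines of the generated code swizzle `Z` to `[M3, M2, M1, M0, N3, N2, N1, N0]` and then merge each group of four ranks back into a single rank, yielding `Z_MN`. The sketch below imitates that flattening for one group on nested Python dicts. It assumes, as `coord_style="absolute"` suggests, that the innermost coordinates after splitting are already absolute `M` (or `N`) coordinates, so merging only drops the partition levels. `merge_levels` is a hypothetical helper, not the fibertree `mergeRanks` implementation.

```python
def merge_levels(fiber, levels):
    """Collapse `levels + 1` nested dict levels into one, keeping only the
    innermost keys, which are assumed to be absolute coordinates."""
    if levels == 0:
        return dict(fiber)
    merged = {}
    for sub_fiber in fiber.values():
        merged.update(merge_levels(sub_fiber, levels - 1))
    return merged


# Hypothetical payloads for rows m = 0, 1, 3, 4, nested as M3 -> M2 -> M1 -> M0.
z_m3 = {
    0: {0: {0: {0: "row 0", 1: "row 1"},
            3: {3: "row 3"}}},
    4: {4: {4: {4: "row 4"}}},
}
print(merge_levels(z_m3, levels=3))
# {0: 'row 0', 1: 'row 1', 3: 'row 3', 4: 'row 4'}
```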

## Check Results

Check that the generated code computes the correct result.

**Note**: Run this cell only after compiling and running the kernel (the cells above).
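`utils.check_matmul` comes from the `src` package imported at the top, which is not shown here. One way such a check can work is to compare the sparse result against a dense reference for `Z = A^T B`. The sketch below does that for plain Python lists and a `{m: {n: value}}` result; both helpers are hypothetical and are not the actual `utils.check_matmul`.

```python
import math


def reference_matmul(a_km, b_kn):
    """Dense reference for Z[m, n] = sum_k A[k, m] * B[k, n], i.e. Z = A^T B."""
    K, M, N = len(a_km), len(a_km[0]), len(b_kn[0])
    z = [[0] * N for _ in range(M)]
    for k in range(K):
        for m in range(M):
            for n in range(N):
                z[m][n] += a_km[k][m] * b_kn[k][n]
    return z


def matches_reference(z_sparse, z_ref):
    """True if the sparse {m: {n: value}} result equals the dense reference;
    missing entries count as zero and values are compared with a tolerance."""
    num_rows, num_cols = len(z_ref), len(z_ref[0])
    # Reject entries that fall outside the reference's shape.
    for m, row in z_sparse.items():
        if not 0 <= m < num_rows or any(not 0 <= n < num_cols for n in row):
            return False
    return all(
        math.isclose(z_sparse.get(m, {}).get(n, 0), z_ref[m][n], abs_tol=1e-9)
        for m in range(num_rows)
        for n in range(num_cols)
    )


A = [[1, 0, 2],
     [0, 3, 0]]  # K = 2, M = 3
B = [[4, 5],
     [6, 0]]     # K = 2, N = 2
Z = {0: {0: 4, 1: 5}, 1: {0: 18}, 2: {0: 8, 1: 10}}
print(matches_reference(Z, reference_matmul(A, B)))  # True
```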

In [None]:
utils.check_matmul(A_KM, B_KN, Z_MN)