# FuseMax

This notebook reproduces the salient characteristics of the [FuseMax](https://arxiv.org/abs/2406.10491) accelerator.

## Imports

Import the necessary modules.

In [None]:
# HiFiber boilerplate

# Wildcard import: the fibertree names used later in this notebook (Tensor,
# Fiber, Payload, createCanvas, displayCanvas, displayTensor) are not imported
# anywhere else, so presumably they come from here.
from fibertree_bootstrap import *

# Configure fibertree display: tree-style tensor rendering, movie animation.
fibertree_bootstrap(style="tree", animation='movie')

# Compilation boilerplate

import os
import sys
# Put the parent directory on the import path so that `src.utils` (used by the
# result check at the end) can be imported. Assumes the notebook is run from a
# subdirectory of the repository root -- TODO confirm.
sys.path.insert(0, "..")

import math

from src import utils

## Initialization

Initialize the input tensors. Tensor shapes and densities can be modified below.

**Warning:** Large tensors will overwhelm the video generation. Either:
1. Use small tensors; as a rule of thumb, fewer than 60 computes should be required. Note that the current set of parameters (the minimum for each rank to have occupancy `> 1`) results in a 98 frame video.
2. Disable the video generation by commenting out the line `displayCanvas(canvas)`. Tensors can still be visualized below.

In [None]:
# Problem sizes (see the cascade above): E is the contracted Q/K dimension,
# F the V feature dimension, M the number of keys/values, P the number of queries.
E = 2
F = 2
M = 4
P = 8

# Tile sizes used by the HiFiber below: M is split into M0-sized tiles, and P is
# split into P1-sized tiles, which are in turn split into P0-sized tiles.
M0 = 2
P0 = 2
P1 = 4

density = [1, 1]
seed = 4

# Give each input its own seed. K ([M, E]) and V ([M, F]) have the same shape
# whenever E == F, so a shared seed would presumably generate identical K and V.
# The final correctness check could then not distinguish K from V.
Q_PE = Tensor.fromRandom(rank_ids=["P", "E"], shape=[P, E], seed=seed, density=density, name="Q")
K_ME = Tensor.fromRandom(rank_ids=["M", "E"], shape=[M, E], seed=seed + 1, density=density, name="K")
V_MF = Tensor.fromRandom(rank_ids=["M", "F"], shape=[M, F], seed=seed + 2, density=density, name="V")

## Running FuseMax with HiFiber

The TeAAL compiler is currently not sophisticated enough to compile FuseMax's cascade. We instead directly provide the HiFiber.

### FuseMax TeAAL Specification

The FuseMax cascade is:

$$QK_{m1,m0,p} = Q_{e,p}\times BK_{e,m1,m0}:\bigvee_{e}+(\cup)$$

$$LM_{m1,p} = QK_{m1,m0,p}:\bigvee_{m0}\text{max}(\cup)$$
$$RM_{m1+1,p} = \max(RM_{m1,p}, LM_{m1,p})$$

$$SLN_{m1,m0,p} = \exp(QK_{m1,m0,p} - RM_{m1+1,p})$$

$$PRM_{m1,p} = \exp(RM_{m1,p} - RM_{m1+1,p})$$
$$SPD_{m1,p} = PRM_{m1,p}\times RD_{m1,p}$$
$$SLD_{m1,p} = SLN_{m1,m0,p}:\bigvee_{m0}+(\cup)$$
$$RD_{m1+1,p} = SPD_{m1,p} + SLD_{m1,p}$$

$$SPNV_{m1,f,p} = PRM_{m1,p}\times RNV_{m1,f,p}$$
$$SLNV_{m1,f,p} = SLN_{m1,m0,p}\times BV_{m1,m0,f}:\bigvee_{m0}+(\cup)$$
$$RNV_{m1+1,f,p} = SPNV_{m1,f,p} +SLNV_{m1,f,p}$$

$$AV_{f,p} = RNV_{M1,f,p} / RD_{M1,p}$$

### FuseMax HiFiber

The following is a loop-nest representation of the HiFiber. Note that this loop nest does not implement any of the software pipelining/interleaving used in the actual FuseMax; we present it this way for pedagogical reasons.

To speed up execution by skipping the video generation, comment out the last line, `displayCanvas(canvas)`.

In [None]:
# Result of the Q x K contraction (tiled over P2/P1/P0 and M1/M0).
QK_P2M1P1P0M0 = Tensor(rank_ids=["P2", "M1", "P1", "P0", "M0"], name="QK")

# Tile the inputs into the layouts the loop nest iterates over:
#   Q[P, E] -> Q[P2, P1, P0, E]   (split P by P1, then the P1 rank by P0)
#   K[M, E] -> K[M1, M0, E]       (split M by M0)
#   V[M, F] -> V[M1, M0, F]       (split M by M0)
Q_P2P1P0E = Q_PE.splitUniform(P1, depth=0).splitUniform(P0, depth=1)
Q_P2P1P0E.setRankIds(rank_ids=["P2", "P1", "P0", "E"])

K_M1M0E = K_ME.splitUniform(M0, depth=0)
K_M1M0E.setRankIds(rank_ids=["M1", "M0", "E"])

V_M1M0F = V_MF.splitUniform(M0, depth=0)
V_M1M0F.setRankIds(rank_ids=["M1", "M0", "F"])

# Intermediate and output tensors of the cascade. RM uses -inf as its default,
# so the "previous" running max read for the first M1 tile is -inf.
LM_P2M1P1P0 = Tensor(rank_ids=["P2", "M1", "P1", "P0"], name="LM")
RM_P2M1P1P0 = Tensor(rank_ids=["P2", "M1", "P1", "P0"], name="RM", default=-float("inf"))
SLN_P2M1P1P0M0 = Tensor(rank_ids=["P2", "M1", "P1", "P0", "M0"], name="SLN")
PRM_P2M1P1P0 = Tensor(rank_ids=["P2", "M1", "P1", "P0"], name="PRM")
RD_P2M1P1P0 = Tensor(rank_ids=["P2", "M1", "P1", "P0"], name="RD")
SPD_P2M1P1P0 = Tensor(rank_ids=["P2", "M1", "P1", "P0"], name="SPD")
SLD_P2M1P1P0 = Tensor(rank_ids=["P2", "M1", "P1", "P0"], name="SLD")
RNV_P2M1P1P0F = Tensor(rank_ids=["P2", "M1", "P1", "P0", "F"], name="RNV")
SPNV_P2M1P1P0F = Tensor(rank_ids=["P2", "M1", "P1", "P0", "F"], name="SPNV")
SLNV_P2M1P1P0F = Tensor(rank_ids=["P2", "M1", "P1", "P0", "F"], name="SLNV", shape=[P, M, P, P, F])
AV_P2P1P0F = Tensor(rank_ids=["P2", "P1", "P0", "F"], name="AV")

# Root fibers that the loop nest starts from.
(qk_p2, q_p2, k_m1, v_m1,
 lm_p2, rm_p2, sln_p2, prm_p2, rd_p2, spd_p2, sld_p2,
 rnv_p2, spnv_p2, slnv_p2, av_p2) = (
    root_tensor.getRoot()
    for root_tensor in (
        QK_P2M1P1P0M0, Q_P2P1P0E, K_M1M0E, V_M1M0F,
        LM_P2M1P1P0, RM_P2M1P1P0, SLN_P2M1P1P0M0, PRM_P2M1P1P0,
        RD_P2M1P1P0, SPD_P2M1P1P0, SLD_P2M1P1P0,
        RNV_P2M1P1P0F, SPNV_P2M1P1P0F, SLNV_P2M1P1P0F, AV_P2P1P0F,
    )
)

# Every addActivity call below passes one highlight per canvas tensor, in this order:
#   Q, K, QK, LM, RM, SLN, PRM, RD, SPD, SLD, V, RNV, SPNV, SLNV, AV
# A highlight is a coordinate tuple in the tensor's own rank order. QK and SLN
# have ranks (P2, M1, P1, P0, M0), so they are highlighted as (p2, m1, p1, p0, m0).
canvas = createCanvas(Q_P2P1P0E, K_M1M0E, QK_P2M1P1P0M0, LM_P2M1P1P0, RM_P2M1P1P0, SLN_P2M1P1P0M0, PRM_P2M1P1P0, RD_P2M1P1P0, SPD_P2M1P1P0, SLD_P2M1P1P0, V_M1M0F, RNV_P2M1P1P0F, SPNV_P2M1P1P0F, SLNV_P2M1P1P0F, AV_P2P1P0F)
for p2_pos, (p2, (av_p1, (slnv_m1, (spnv_m1, (sld_m1, (spd_m1, (prm_m1, (sln_m1, (lm_m1, (qk_m1, q_p1)))))))))) in enumerate(av_p2 << (slnv_p2 << (spnv_p2 << (sld_p2 << (spd_p2 << (prm_p2 << (sln_p2 << (lm_p2 << (qk_p2 << q_p2))))))))):
    # Walk the K/V tiles. The running statistics (RM, RD, RNV) for tile m1 are
    # read at coordinate m1 and written at coordinate m1 + M0, the next tile's start.
    for m1_pos, (m1, (slnv_p1, (spnv_p1, (sld_p1, (spd_p1, (prm_p1, (sln_p1, (lm_p1, (qk_p1, (v_m0, k_m0)))))))))) in enumerate(slnv_m1 << (spnv_m1 << (sld_m1 << (spd_m1 << (prm_m1 << (sln_m1 << (lm_m1 << (qk_m1 << (v_m1 & k_m1))))))))):
        for p1_pos, (p1, (slnv_p0, (spnv_p0, (sld_p0, (spd_p0, (prm_p0, (sln_p0, (lm_p0, (qk_p0, q_p0))))))))) in enumerate(slnv_p1 << (spnv_p1 << (sld_p1 << (spd_p1 << (prm_p1 << (sln_p1 << (lm_p1 << (qk_p1 << q_p1)))))))):
            for p0_pos, (p0, (slnv_f, (spnv_f, (sld_ref, (spd_ref, (prm_ref, (sln_m0, (lm_ref, (qk_m0, q_e))))))))) in enumerate(slnv_p0 << (spnv_p0 << (sld_p0 << (spd_p0 << (prm_p0 << (sln_p0 << (lm_p0 << (qk_p0 << q_p0)))))))):
                # QK = Q x K (reduced over E), then LM = max over the M0 tile of QK.
                # NOTE(review): LM has no explicit default (presumably 0), so LM may
                # exceed the true max. AV is unaffected because the same RM scales
                # both the numerator (RNV) and the denominator (RD).
                for m0_pos, (m0, (qk_ref, k_e)) in enumerate(qk_m0 << k_m0):
                    for e_pos, (e, (q_val, k_val)) in enumerate(q_e & k_e):
                        qk_ref += q_val * k_val
                        canvas.addActivity((p2, p1, p0, e), (m1, m0, e), (p2, m1, p1, p0, m0), (), (), (), (), (), (), (), (), (), (), (), (), spacetime=((m0_pos, p0_pos), (p2_pos, m1_pos, p1_pos, e_pos)))
                        
                    lm_ref <<= max(lm_ref, qk_ref)
                    canvas.addActivity((), (), (p2, m1, p1, p0, m0), (p2, m1, p1, p0), (), (), (), (), (), (), (), (), (), (), (), spacetime=((m0_pos, p0_pos), (p2_pos, m1_pos, p1_pos, E, 0)))

                    
                # Running max: RM[m1 + M0] = max(RM[m1], LM[m1]).
                rm_prev = rm_p2.getPayload(p2, m1, p1, p0)
                rm_next = rm_p2.getPayloadRef(p2, m1 + M0, p1, p0)
                rm_next <<= max(rm_prev, lm_ref)
                canvas.addActivity((), (), (), (p2, m1, p1, p0), (p2, m1, p1, p0), (), (), (), (), (), (), (), (), (), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E, 1)))
                canvas.addActivity((), (), (), (), (p2, m1 + M0, p1, p0), (), (), (), (), (), (), (), (), (), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E, 1)))

                # Per M0 element: the numerator SLN, its sum SLD, and SLN x V into SLNV.
                for m0_pos, (m0, (sln_ref, (qk_ref, v_f))) in enumerate(sln_m0 << (qk_m0 & v_m0)):
                    sln_ref <<= math.exp(Payload.get(qk_ref - rm_next))
                    canvas.addActivity((), (), (p2, m1, p1, p0, m0), (), (p2, m1 + M0, p1, p0), (p2, m1, p1, p0, m0), (), (), (), (), (), (), (), (), (), spacetime=((m0_pos, p0_pos), (p2_pos, m1_pos, p1_pos, E, 2)))

                    sld_ref += sln_ref
                    canvas.addActivity((), (), (), (), (), (p2, m1, p1, p0, m0), (), (), (), (p2, m1, p1, p0), (), (), (), (), (), spacetime=((m0_pos, p0_pos), (p2_pos, m1_pos, p1_pos, E, 3)))

                    for f_pos, (f, (slnv_ref, v_val)) in enumerate(slnv_f << v_f):
                        slnv_ref += sln_ref * v_val
                        canvas.addActivity((), (), (), (), (), (p2, m1, p1, p0, m0), (), (), (), (), (m1, m0, f), (), (), (p2, m1, p1, p0, f), (), spacetime=((m0_pos, p0_pos), (p2_pos, m1_pos, p1_pos, E + 1, f_pos, 0)))

                # Rescale factor for the previous running sums. RM defaults to -inf,
                # so for the first tile this is exp(-inf) == 0.0.
                prm_ref <<= math.exp(Payload.get(rm_prev - rm_next))
                canvas.addActivity((), (), (), (), (p2, m1, p1, p0), (), (p2, m1, p1, p0), (), (), (), (), (), (), (), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E, 2)))
                canvas.addActivity((), (), (), (), (p2, m1 + M0, p1, p0), (), (), (), (), (), (), (), (), (), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E, 2)))

                # Running denominator: RD[m1 + M0] = PRM * RD[m1] + SLD.
                rd_prev = rd_p2.getPayload(p2, m1, p1, p0)
                spd_ref <<= prm_ref * rd_prev
                canvas.addActivity((), (), (), (), (), (), (p2, m1, p1, p0), (p2, m1, p1, p0), (p2, m1, p1, p0), (), (), (), (), (), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E, 3)))

                rd_next = rd_p2.getPayloadRef(p2, m1 + M0, p1, p0)
                rd_next <<= spd_ref + sld_ref
                canvas.addActivity((), (), (), (), (), (), (), (p2, m1 + M0, p1, p0), (p2, m1, p1, p0), (p2, m1, p1, p0), (), (), (), (), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E, 4)))

                # Running numerator: RNV[m1 + M0, f] = PRM * RNV[m1, f] + SLNV[f].
                # coiterShape walks every F coordinate, including ones where the
                # previous RNV fiber is still empty (first tile).
                rnv_f_prev = rnv_p2.getPayload(p2, m1, p1, p0)
                rnv_f_next = rnv_p2.getPayloadRef(p2, m1 + M0, p1, p0)
                for f_pos, (f, (rnv_next, (spnv_ref, (slnv_val, rnv_prev)))) in enumerate(rnv_f_next << (spnv_f << Fiber.coiterShape((slnv_f, rnv_f_prev)))):
                    spnv_ref <<= rnv_prev * prm_ref
                    canvas.addActivity((), (), (), (), (), (), (p2, m1, p1, p0), (), (), (), (), (p2, m1, p1, p0, f), (p2, m1, p1, p0, f), (), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E + 1, f_pos, 0)))

                    rnv_next <<= spnv_ref + slnv_val
                    canvas.addActivity((), (), (), (), (), (), (), (), (), (), (), (p2, m1 + M0, p1, p0, f), (p2, m1, p1, p0, f), (p2, m1, p1, p0, f), (), spacetime=((p0_pos,), (p2_pos, m1_pos, p1_pos, E + 1, f_pos, 1)))

    # Final division AV = RNV / RD, using the last M1 coordinate that was written
    # (m1 + M0 for the final tile), i.e. the fully accumulated statistics.
    rd_m1 = rd_p2.getPayload(p2)
    rd_p1 = rd_m1.getPayload(rd_m1.coords[-1])
    rnv_m1 = rnv_p2.getPayload(p2)
    rnv_p1 = rnv_m1.getPayload(rnv_m1.coords[-1])
    for p1_pos, (p1, (av_p0, (rd_p0, rnv_p0))) in enumerate(av_p1 << (rd_p1 & rnv_p1)):
        for p0_pos, (p0, (av_f, (rd_val, rnv_f))) in enumerate(av_p0 << (rd_p0 & rnv_p0)):
            for f_pos, (f, (av_ref, rnv_val)) in enumerate(av_f << rnv_f):
                av_ref <<= rnv_val / rd_val
                canvas.addActivity((), (), (), (), (), (), (), (p2, M, p1, p0), (), (), (), (p2, M, p1, p0, f), (), (), (p2, p1, p0, f), spacetime=((p0_pos,), (p2_pos, M, p1_pos, f_pos)))

# Merge P2/P1/P0 back into a single absolute-coordinate P rank for the check below.
AV_PF = AV_P2P1P0F.mergeRanks(depth=0, levels=2, coord_style="absolute")

displayCanvas(canvas)

### Visualize the Static Tensors

As an alternative to the video generation, the kernel can be visualized by inspecting the involved tensors (below).

In [None]:
# Render each tensor of the cascade, inputs and tiled intermediates first, in a fixed order.
for static_tensor in (
    Q_P2P1P0E,
    K_M1M0E,
    QK_P2M1P1P0M0,
    LM_P2M1P1P0,
    RM_P2M1P1P0,
    PRM_P2M1P1P0,
    SLN_P2M1P1P0M0,
    SPD_P2M1P1P0,
    RD_P2M1P1P0,
    SLD_P2M1P1P0,
    V_M1M0F,
    SLNV_P2M1P1P0F,
    SPNV_P2M1P1P0F,
    RNV_P2M1P1P0F,
    AV_P2P1P0F,
):
    displayTensor(static_tensor)

### Check Results

Check that the HiFiber above computes the correct result.

**Note**: Run this only after running the FuseMax HiFiber cell above, which produces `AV_PF`.

In [None]:
# Validate the HiFiber output AV_PF (AV with P2/P1/P0 merged back into P)
# against the original, untiled Q, K and V inputs. Presumably check_attn
# computes a reference attention result and compares -- see src/utils.py.
utils.check_attn(Q_PE, K_ME, V_MF, AV_PF)