## Tiled spMspM

This notebook is an example of tiling sparse matrix-sparse matrix multiply (spMspM). The tiling is done on both the rows and columns (i.e., both ranks) of the __A__ and __B__ operand tensors and of the output tensor (__Z__). In these examples the tiling is done monolithically in a separate step (essentially as pre-processing), and then a tiled dataflow is modeled.

First, import the required libraries.


In [None]:
# Begin - startup boilerplate code

import pkgutil

if 'fibertree_bootstrap' not in [pkg.name for pkg in pkgutil.iter_modules()]:
  !python3 -m pip  install git+https://github.com/Fibertree-project/fibertree-bootstrap --quiet

# End - startup boilerplate code


from fibertree_bootstrap import *
fibertree_bootstrap(style="tree", animation="movie")

## Initialize setup

The following cell just creates some sliders to control the creation of the input operand tensors.

In [None]:
# Initial values

M = 8
N = 8
K = 6
density = [0.9, 0.8]
seed = 10

def set_params(rank_M, rank_N, rank_K, tensor_density, rand_seed):
    global M
    global N
    global K
    global density
    global seed
    
    M = rank_M
    N = rank_N
    K = rank_K
    
    if tensor_density == 'sparse':
        density = [0.9, 0.8]
    elif tensor_density == 'sparser':
        density = [0.9, 0.4]
    else:
        density = [1.0, 1.0]
        
    seed = rand_seed

interactive(set_params,
            rank_M=widgets.IntSlider(min=2, max=12, step=1, value=M),
            rank_N=widgets.IntSlider(min=2, max=12, step=1, value=N),
            rank_K=widgets.IntSlider(min=2, max=12, step=1, value=K),
            tensor_density=['sparser', 'sparse', 'dense'],
            rand_seed=widgets.IntSlider(min=0, max=100, step=1, value=seed))

## Create Input Tensors

Because this notebook uses both the original and the rank-swapped versions of the operands, each tensor name is suffixed with the names of its ranks in order (for example, `a_MK` has ranks `M` then `K`).
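As a rough illustration of what swapping ranks means, the sketch below models a two-rank sparse tensor as a plain dict of dicts and exchanges its two ranks. It does not use fibertree: the helper name `swap_ranks` and the sample data are made up for this example, and it only assumes that `swapRanks()` exchanges the two uppermost ranks.

```python
def swap_ranks(t):
    """Return a copy of a dict-of-dicts tensor with its two ranks exchanged."""
    swapped = {}
    for upper, fiber in t.items():
        for lower, value in fiber.items():
            swapped.setdefault(lower, {})[upper] = value
    # Sort coordinates so every fiber is in increasing coordinate order
    return {c: dict(sorted(f.items())) for c, f in sorted(swapped.items())}


# Hypothetical sparse matrix in M->K order: {m: {k: value}}
a_mk = {0: {1: 5, 3: 2}, 2: {0: 7, 1: 4}}

# The same nonzeros in K->M order: {k: {m: value}}
a_km = swap_ranks(a_mk)
print(a_km)
```

Swapping twice returns the original tensor, which is a quick sanity check on any rank-swap routine.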


In [None]:
a_MK = Tensor.fromRandom(["M", "K"], [M, K], density, 5, seed=seed)
a_MK.setColor("blue").setName("a_MK")
displayTensor(a_MK)

# Create swapped rank version of a
a_KM = a_MK.swapRanks()
a_KM.setName("a_KM")
displayTensor(a_KM)

b_NK = Tensor.fromRandom(["N", "K"], [N, K], density, 5, seed=2*seed)
b_NK.setColor("green").setName("b_NK")
displayTensor(b_NK)

# Create swapped rank version of b
b_KN = b_NK.swapRanks()
b_KN.setName("b_KN")
displayTensor(b_KN)



## Output Stationary/Inner Product

Plain untiled matrix multiply as a reference.
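For readers unfamiliar with the fibertree operators, the following plain-Python sketch shows the same inner-product loop order on dict-of-dicts operands. It is not the fibertree implementation: the function name `spmspm_inner` and the sample data are assumptions for illustration, and in this sketch an empty intersection simply produces no output entry.

```python
def spmspm_inner(a_mk, b_nk):
    """Inner-product (output-stationary) sparse matrix multiply.

    a_mk: {m: {k: value}}, b_nk: {n: {k: value}}
    Returns z_mn: {m: {n: value}}.
    """
    z_mn = {}
    for m, a_k in a_mk.items():
        for n, b_k in b_nk.items():
            # Intersection: only k coordinates present in both fibers contribute
            common = a_k.keys() & b_k.keys()
            if not common:
                continue
            z_mn.setdefault(m, {})[n] = sum(a_k[k] * b_k[k] for k in common)
    return z_mn


# Hypothetical operands; note that B is stored in N->K order, like b_NK
a_mk = {0: {0: 1, 2: 3}, 1: {1: 2}}
b_nk = {0: {0: 4}, 1: {1: 5, 2: 6}}

z_mn = spmspm_inner(a_mk, b_nk)
print(z_mn)
```

Each output value is fully reduced over `k` before the loop moves to the next `(m, n)` point, which is what makes this order output-stationary.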

In [None]:
z_MN = Tensor(rank_ids=["M", "N"], shape=[M, N])
z_MN.setName("z_MN")

a_m = a_MK.getRoot()
b_n = b_NK.getRoot()
z_m = z_MN.getRoot()

canvas = createCanvas(a_MK, b_NK, z_MN)

for m, (z_n, a_k) in z_m << a_m:
    for n, (z_ref, b_k) in z_n << b_n:
        for k, (a_val, b_val) in a_k & b_k:
            z_ref += a_val * b_val
            canvas.addFrame((m, k), (n, k), (m, n))

displayTensor(z_MN)
displayCanvas(canvas)



## Tile the tensors (v1)

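Pre-process the tensors into a 2-D tiled form (resulting in 4 ranks) in a rank order that is natural for output-stationary tiles over output-stationary values. Each of the `M`, `N`, and `K` ranks is split into two tiles, so __A__ ends up with ranks `M1, K1, M0, K0` and __B__ with ranks `N1, K1, N0, K0`.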
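The regrouping can be sketched without fibertree on a dict-of-dicts matrix. The helpers `ceil_div` and `tile_2d` below are hypothetical names, and the sketch assumes that each tile is labeled by the first coordinate it covers while the inner ranks keep their original coordinates; this convention is an assumption of the sketch and has not been checked against the fibertree display.

```python
def ceil_div(extent, parts):
    """Tile size needed to cover `extent` coordinates with `parts` tiles."""
    return (extent + parts - 1) // parts


def tile_2d(t, size0, size1):
    """Regroup {r: {c: v}} into {r1: {c1: {r0: {c0: v}}}}.

    r1 and c1 are the first coordinate of each tile (an assumption of this
    sketch); r0 and c0 keep their original, untiled coordinate values.
    """
    tiled = {}
    for r, fiber in t.items():
        r1 = (r // size0) * size0
        for c, v in fiber.items():
            c1 = (c // size1) * size1
            tiled.setdefault(r1, {}).setdefault(c1, {}).setdefault(r, {})[c] = v
    return tiled


M, K = 8, 6
M0, K0 = ceil_div(M, 2), ceil_div(K, 2)  # two tiles per rank

# Hypothetical sparse matrix in M->K order
a_mk = {0: {0: 1, 4: 2}, 5: {2: 3}, 7: {5: 4}}
a_mkmk = tile_2d(a_mk, M0, K0)
print(a_mkmk)
```

Only tiles that contain at least one nonzero appear in the result, so sparsity is preserved at the tile level as well as at the element level.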

In [None]:
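# Split each rank into 2 tiles; the tile size is a ceiling division
# so that the tiles cover the whole rank even when it is odd-sized
M1 = 2
M0 = (M + M1 - 1)//M1

N1 = 2
N0 = (N + N1 - 1)//N1

K1 = 2
K0 = (K + K1 - 1)//K1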

a_MKMK = a_MK.splitUniform(M0).splitUniform(K0, depth=2).swapRanks(depth=1)
a_MKMK.setName("a_MKMK")
displayTensor(a_MKMK)


b_NKNK = b_NK.splitUniform(N0).splitUniform(K0, depth=2).swapRanks(depth=1)
b_NKNK.setName("b_NKNK")
displayTensor(b_NKNK)


z_MNMN_check = z_MN.splitUniform(M0).splitUniform(N0, depth=2).swapRanks(depth=1)
displayTensor(z_MN)
displayTensor(z_MNMN_check)

## Tiled spMspM (v1)

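Dataflow for output-stationary tiles over output-stationary values. The three outer loops (`m1`, `n1`, `k1`) walk the tiles, and the three inner loops (`m0`, `n0`, `k0`) walk the values inside the current tiles.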
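The same six-deep loop nest can be sketched in plain Python and checked against an untiled multiply. Nothing here is fibertree API: `tile_2d`, `matmul_flat`, `matmul_tiled_os`, and the sample operands are hypothetical, tiles are assumed to be labeled by their first coordinate, and the output is flattened to `{(m, n): value}` to make the comparison easy.

```python
from collections import defaultdict


def tile_2d(t, size0, size1):
    """Regroup {r: {c: v}} into {r1: {c1: {r0: {c0: v}}}} (tile = first coord)."""
    tiled = {}
    for r, fiber in t.items():
        for c, v in fiber.items():
            r1, c1 = (r // size0) * size0, (c // size1) * size1
            tiled.setdefault(r1, {}).setdefault(c1, {}).setdefault(r, {})[c] = v
    return tiled


def matmul_flat(a_mk, b_nk):
    """Untiled reference: returns {(m, n): value}."""
    z = defaultdict(int)
    for m, a_k in a_mk.items():
        for n, b_k in b_nk.items():
            for k in a_k.keys() & b_k.keys():
                z[(m, n)] += a_k[k] * b_k[k]
    return dict(z)


def matmul_tiled_os(a, b):
    """a: {m1:{k1:{m0:{k0:v}}}}, b: {n1:{k1:{n0:{k0:v}}}} -> {(m, n): value}."""
    z = defaultdict(int)
    for m1, a_k1 in a.items():
        for n1, b_k1 in b.items():
            for k1 in a_k1.keys() & b_k1.keys():           # tile-level intersection
                a_m0, b_n0 = a_k1[k1], b_k1[k1]
                for m0, a_k0 in a_m0.items():
                    for n0, b_k0 in b_n0.items():
                        for k0 in a_k0.keys() & b_k0.keys():  # element-level
                            z[(m0, n0)] += a_k0[k0] * b_k0[k0]
    return dict(z)


# Hypothetical 8x6 operands; tile sizes M0 = N0 = 4 and K0 = 3
a_mk = {0: {0: 1, 4: 2}, 1: {4: 3}, 5: {2: 3}, 7: {5: 4}}
b_nk = {0: {0: 2, 5: 1}, 3: {4: 5}, 6: {2: 1, 5: 2}}

z_flat = matmul_flat(a_mk, b_nk)
z_tiled = matmul_tiled_os(tile_2d(a_mk, 4, 3), tile_2d(b_nk, 4, 3))
print(z_tiled == z_flat)
```

The tile-level intersection on `k1` skips whole pairs of tiles that share no `K` range with nonzeros, before any element-level work is done.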


In [None]:
z_MNMN = Tensor(rank_ids=["M1", "N1", "M0", "N0"])
z_MNMN.setName("z_MNMN")

a_m1 = a_MKMK.getRoot()
b_n1 = b_NKNK.getRoot()
z_m1 = z_MNMN.getRoot()

canvas = createCanvas(a_MKMK, b_NKNK, z_MNMN)

for m1, (z_n1, a_k1) in z_m1 << a_m1:
    for n1, (z_m0, b_k1) in z_n1 << b_n1:
        for k1, (a_m0, b_n0) in a_k1 & b_k1:
            for m0, (z_n0, a_k0) in z_m0 << a_m0:
                for n0, (z_ref, b_k0) in z_n0 << b_n0:
                    for k0, (a_val, b_val) in a_k0 & b_k0:
                        z_ref += a_val * b_val
                        
                        # Show the currently active tiles
                        canvas.addActivity((m1, k1),
                                            (n1, k1),
                                            (m1, n1),
                                            worker="T")
                        
                        # Show the currently active elements
                        canvas.addFrame((m1, k1, m0, k0),
                                         (n1, k1, n0, k0),
                                         (m1, n1, m0, n0))

displayTensor(z_MNMN)
displayCanvas(canvas)

In [None]:
# Check that result is correct

z_MNMN.getRoot() == z_MNMN_check.getRoot()

## Tile the tensors (v2)

Pre-process the tensors into a 2-D tiled form (resulting in 4 ranks) in a rank order that is natural for A-stationary tiles over output-stationary values. The tiled __A__ operand from the previous example is reused, so only a new tiling of the __B__ operand is needed, this time with ranks `K1, N1, N0, K0`.
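To see how a chain of splits and swaps produces a given rank order, the sketch below tracks only the rank names. The helpers `split` and `swap` are hypothetical stand-ins that assume a split at depth `d` replaces rank `X` with `X1, X0` and a swap at depth `d` exchanges ranks `d` and `d+1`; they mimic the assumed effect of `splitUniform` and `swapRanks` on rank order and do not touch any data.

```python
def split(ranks, depth=0):
    """Replace the rank at `depth` with an upper (X1) and a lower (X0) rank."""
    r = ranks[depth]
    return ranks[:depth] + [r + "1", r + "0"] + ranks[depth + 1:]


def swap(ranks, depth=0):
    """Exchange the ranks at `depth` and `depth + 1`."""
    out = list(ranks)
    out[depth], out[depth + 1] = out[depth + 1], out[depth]
    return out


# A tiling used in both versions: split M, split K at depth 2, swap at depth 1
a_ranks = swap(split(split(["M", "K"]), depth=2), depth=1)

# New B tiling: split K, swap at depth 1, then split N at depth 1
b_ranks = split(swap(split(["K", "N"]), depth=1), depth=1)

print(a_ranks)
print(b_ranks)
```

Under these assumptions both operands end with `K0` as the innermost rank, which is what allows the innermost loop to intersect their `K0` fibers directly.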

In [None]:
# We just need a new B tiling

b_KNNK = b_KN.splitUniform(K0).swapRanks(depth=1).splitUniform(N0, depth=1)
b_KNNK.setName("b_KNNK")
displayTensor(b_KNNK)


## Tiled spMspM (v2)

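Dataflow for A-stationary tiles over output-stationary values. Compared with v1, the `k1` and `n1` tile loops are interchanged, so each __A__ tile `(m1, k1)` stays in place while the `n1` loop runs.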
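One way to see what "stationary" means is to count how often each tile changes between consecutive steps of the tile-level loops. The sketch below does this for the v1 and v2 loop orders; it assumes every tile is non-empty (a dense tile grid), and `count_changes` is a hypothetical helper rather than part of fibertree.

```python
from itertools import product


def count_changes(seq):
    """Count positions where an item differs from its predecessor (first counts)."""
    changes, prev = 0, None
    for item in seq:
        if item != prev:
            changes += 1
            prev = item
    return changes


M1, N1, K1 = 2, 2, 2  # tiles per rank, assuming no tile is empty

# v1 tile loop order: m1, n1, k1 (output tile held stationary)
v1 = [(m1, n1, k1) for m1, n1, k1 in product(range(M1), range(N1), range(K1))]
# v2 tile loop order: m1, k1, n1 (A tile held stationary)
v2 = [(m1, n1, k1) for m1, k1, n1 in product(range(M1), range(K1), range(N1))]

# A tiles are indexed by (m1, k1), Z tiles by (m1, n1)
v1_a = count_changes([(m1, k1) for m1, n1, k1 in v1])
v1_z = count_changes([(m1, n1) for m1, n1, k1 in v1])
v2_a = count_changes([(m1, k1) for m1, n1, k1 in v2])
v2_z = count_changes([(m1, n1) for m1, n1, k1 in v2])

print("v1: A tile changes", v1_a, "Z tile changes", v1_z)
print("v2: A tile changes", v2_a, "Z tile changes", v2_z)
```

The trade-off is symmetric in this small case: v2 halves the number of A tile changes relative to v1, at the cost of revisiting each output tile once per `k1`.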

In [None]:
z_MNMN = Tensor(rank_ids=["M1", "N1", "M0", "N0"])
z_MNMN.setName("z_MNMN")

a_m1 = a_MKMK.getRoot()
b_k1 = b_KNNK.getRoot()
z_m1 = z_MNMN.getRoot()

canvas = createCanvas(a_MKMK, b_KNNK, z_MNMN)

for m1, (z_n1, a_k1) in z_m1 << a_m1:
    for k1, (a_m0, b_n1) in a_k1 & b_k1:
        for n1, (z_m0, b_n0) in z_n1 << b_n1:
            
            for m0, (z_n0, a_k0) in z_m0 << a_m0:
                for n0, (z_ref, b_k0) in z_n0 << b_n0:
                    for k0, (a_val, b_val) in a_k0 & b_k0:
                        z_ref += a_val * b_val
                                                                            
                        # Show the currently active tiles
                        canvas.addActivity((m1, k1),
                                            (k1, n1),
                                            (m1, n1),
                                            worker="T")
                        
                        # Show the currently active elements
                        canvas.addFrame((m1, k1, m0, k0),
                                         (k1, n1, n0, k0),
                                         (m1, n1, m0, n0))

displayTensor(z_MNMN)
displayCanvas(canvas)

In [None]:
# Check that result is correct

z_MNMN.getRoot() == z_MNMN_check.getRoot()