# Fused tiled matmul

This is a follow-up to the tiled `matmul` notebook.

Fusing the `matmul` operations reduces memory accesses and speeds things up, because the intermediate outputs are never materialized (written to global memory).

An introduction to the approach can be found here:
https://github.com/NVIDIA/cutlass/tree/master/examples/13_two_tensor_op_fusion

The main trick (and constraint) is to make one of the tile axes as long as the corresponding axis of the second matrix in the multiplication.
There is then no need to iterate over this axis. It won't work if the resulting tile is too big for the `SRAM`.

> In particular, if `Bx` is the right-hand matrix of the second `matmul` operation, the length of each `bx` tile along the `n` axis must equal the full length `Nx` of `Bx` along that axis (we follow the M, K, N axis names from GEMM).

This constrains the length of that axis, which must be small enough for the tile to be kept in shared memory / registers.
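
To get a feel for this constraint, the sketch below estimates how many bytes one full-width tile of the intermediate matrix occupies and compares it to a memory budget. The helper name `intermediate_tile_bytes` and the 48 KiB budget are assumptions made for illustration only. Actual capacity depends on the GPU, and a real kernel also needs room for the input tiles and the accumulator.

```python
def intermediate_tile_bytes(block_m: int, n0: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to hold one (block_m x n0) tile of the intermediate matrix."""
    return block_m * n0 * bytes_per_element


# Assumed on-chip budget, for illustration only (not a property of any specific GPU).
ASSUMED_SRAM_BUDGET = 48 * 1024

for n0 in (64, 512, 4096):
    size = intermediate_tile_bytes(block_m=16, n0=n0)
    print(f"N0={n0}: {size} bytes, fits={size <= ASSUMED_SRAM_BUDGET}")
```

Because the tile spans the whole `N0` axis, its size grows linearly with `N0`; only `block_M` can be shrunk to compensate.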

In the example below, we chain 2 `matmul`:

* A1 = A0 @ B0
* A2 = A1 @ B1

> Our goal is to never materialize `A1`.
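
For comparison, here is a minimal sketch of the unfused version, in which `A1` is fully materialized before the second `matmul` starts. The sizes match the ones used later in this notebook; the byte count simply shows what the fused version avoids writing and then reading back.

```python
import torch

torch.manual_seed(0)
M, K0, N0, N1 = 15, 12, 9, 12

A0 = torch.rand((M, K0))
B0 = torch.rand((K0, N0))
B1 = torch.rand((N0, N1))

# Unfused: the whole (M x N0) intermediate is built before the second matmul.
A1 = A0 @ B0
A2 = A1 @ B1

# Size of the intermediate that the fused version never stores in full.
intermediate_bytes = A1.numel() * A1.element_size()
print(A1.shape, A2.shape, intermediate_bytes)
```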

In [1]:
import torch

M, N0, K0 = 15, 9, 12

A0 = torch.rand((M, K0))
B0 = torch.rand((K0, N0))

# block_Nx is always equal to Nx
# for simplicity, M and K0 are multiples of block_M and block_K0
block_M, block_N0, block_K0 = M // 3, N0, K0 // 3

# K1 always equals N0: A1 = A0 @ B0 has N0 columns, and A1 is the left operand of the second matmul
N1, K1 = 12, N0

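# the second matmul iterates over K1 (== N0) in steps of block_K1
# here K1 = 9 is not a multiple of block_K1 = 2, so the last tile is narrower:
# Python slicing clamps it, whereas a GPU kernel would need masking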
block_N1, block_K1 = N1, block_K0 // 2

# initialize B1 matrix
B1 = torch.rand((K1, N1))

Some important shapes:

* shape of `A1 = matmul(A0, B0)` is `MxN0`, iterate over `K0`
* shape of `B1` is `K1xN1` with `K1 == N0`
* shape of `A2 = matmul(A1, B1)` is `MxN1`, iterate over `K1`
  * because `K1 == N0`, during the second matmul we iterate over `N0`

So we will set the following tile shapes:
* `block_N0 = N0`
* `block_N1 = N1`

> In the code, `:block_N0` and `:block_N1` are used instead of `:` for readability.
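
As a quick sanity check on the tiling, the sketch below lists the slices each loop visits for the sizes used in this notebook (`block_M = 5`, `block_K0 = 4`, `block_K1 = 2`). The helper `tile_bounds` is hypothetical and only mirrors what `range` plus Python slicing do.

```python
def tile_bounds(length: int, block: int) -> list[tuple[int, int]]:
    """(start, end) pairs visited when tiling an axis; the end is clamped like a Python slice."""
    return [(start, min(start + block, length)) for start in range(0, length, block)]


print("M :", tile_bounds(15, 5))
print("K0:", tile_bounds(12, 4))
print("K1:", tile_bounds(9, 2))
```

The `M` and `K0` axes split evenly into three tiles each. The `K1` axis has length 9 and a block size of 2, so its last tile covers a single index; slicing handles that silently here, while a GPU kernel would have to mask it.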

In [2]:
# output accumulator for A2, shape M x N1
accumulator2 = torch.zeros((M, N1))

for index_M in range(0, M, block_M):
    start_M = index_M
    end_M = index_M + block_M
    for index_K0 in range(0, K0, block_K0):
        start_K0 = index_K0
        end_K0 = index_K0 + block_K0

        tile_A0 = A0[start_M:end_M, start_K0:end_K0]
        tile_B0 = B0[start_K0:end_K0, :block_N0]
        # partial sum of the A1 rows for this K0 slice, shape block_M x N0
        # it only lives inside the loop: the full A1 is never built
        tile_A1 = tile_A0 @ tile_B0
        for index_K1 in range(0, K1, block_K1):
            start_K1 = index_K1
            end_K1 = index_K1 + block_K1

            tile_tile_A1 = tile_A1[:, start_K1:end_K1]
            tile_B1 = B1[start_K1:end_K1, :block_N1]

            # matmul is linear, so each partial A1 tile can be multiplied by B1 right away
            accumulator2[start_M:end_M, :block_N1] += tile_tile_A1 @ tile_B1

# compare with the unfused reference
assert torch.allclose(accumulator2, (A0 @ B0) @ B1)
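
The loop is valid because matrix multiplication distributes over the sum across `K0` tiles: multiplying each partial `A1` tile by `B1` and summing the results gives the same value as summing the partial tiles first. As a minimal sketch, the same loop can be wrapped in a function and checked on other shapes. The name `fused_matmul` and its signature are hypothetical, and `float64` is used here only to make the comparison robust to a different summation order.

```python
import torch


def fused_matmul(A0, B0, B1, block_M, block_K0, block_K1):
    """Compute (A0 @ B0) @ B1 tile by tile without building the full A0 @ B0."""
    M, K0 = A0.shape
    K0_b, N0 = B0.shape
    K1, N1 = B1.shape
    assert K0 == K0_b and K1 == N0, "inner dimensions must match"

    out = torch.zeros((M, N1), dtype=A0.dtype)
    for start_M in range(0, M, block_M):
        end_M = start_M + block_M
        for start_K0 in range(0, K0, block_K0):
            end_K0 = start_K0 + block_K0
            # Partial (block_M x N0) tile of A1 for this K0 slice.
            tile_A1 = A0[start_M:end_M, start_K0:end_K0] @ B0[start_K0:end_K0, :]
            for start_K1 in range(0, K1, block_K1):
                end_K1 = start_K1 + block_K1
                # Slices that run past the end of an axis are clamped by Python.
                out[start_M:end_M, :] += tile_A1[:, start_K1:end_K1] @ B1[start_K1:end_K1, :]
    return out


torch.manual_seed(0)
A0 = torch.rand((8, 6), dtype=torch.float64)
B0 = torch.rand((6, 10), dtype=torch.float64)
B1 = torch.rand((10, 4), dtype=torch.float64)

result = fused_matmul(A0, B0, B1, block_M=4, block_K0=3, block_K1=5)
assert torch.allclose(result, (A0 @ B0) @ B1)
```

This stays a plain PyTorch illustration of the loop structure; in an actual GPU kernel the partial tile would be held in shared memory or registers rather than in a tensor.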