# Tiled matmul

In this simple tutorial, we will see how to perform a `matmul` with tiling.
Tiling partitions a matrix into blocks, and each block is called a tile.
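
To make the partition concrete, here is a minimal sketch that splits a matrix into tiles with `torch.split` and then concatenates them back into the original matrix. The 6x4 matrix and the 2x2 tile shape are arbitrary choices for illustration.

```python
import torch

X = torch.arange(24, dtype=torch.float32).reshape(6, 4)
tile_rows, tile_cols = 2, 2

# Partition: split into row blocks, then split each row block into column blocks.
tiles = [row_block.split(tile_cols, dim=1) for row_block in X.split(tile_rows, dim=0)]

# 3 row blocks x 2 column blocks, each tile being 2x2
print(len(tiles), len(tiles[0]), tuple(tiles[0][0].shape))  # → 3 2 (2, 2)

# Reassemble: concatenate tiles along columns, then row blocks along rows.
X_rebuilt = torch.cat([torch.cat(row, dim=1) for row in tiles], dim=0)
assert torch.equal(X, X_rebuilt)
```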

With tiling, in a `matmul`:
* computation can be performed in parallel, a domain where GPUs excel;
* global memory (GM) accesses are reduced, because each tile loaded from GM is reused for several computations; GM access, rather than computation, is often the GPU bottleneck.
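
The reduction in GM accesses can be estimated with a simple counting model. The sketch below assumes that a naive `matmul` reads every operand element from GM each time it is used, while the tiled version loads each `A` and `B` tile once per tile product and then reuses it from fast on-chip memory (hardware caches are ignored). The helpers `naive_loads` and `tiled_loads` are hypothetical names used only for this illustration.

```python
# Hypothetical helpers implementing a simplified GM load-count model (caches ignored).
def naive_loads(M: int, N: int, K: int) -> int:
    # Each of the M*N outputs reads a row of A (K values) and a column of B (K values).
    return 2 * M * N * K


def tiled_loads(M: int, N: int, K: int, block_M: int, block_N: int, block_K: int) -> int:
    # Assumes matrix shapes are multiples of the tile shape.
    assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0
    n_tile_products = (M // block_M) * (N // block_N) * (K // block_K)
    # Each tile product loads one A tile (block_M x block_K) and one B tile (block_K x block_N).
    return n_tile_products * (block_M * block_K + block_K * block_N)


M, N, K = 15, 9, 12
print(naive_loads(M, N, K))           # 3240
print(tiled_loads(M, N, K, 5, 3, 4))  # 864
```

Equivalently, the tiled count is `M*N*K*(1/block_N + 1/block_M)`, so each element loaded from GM is reused more as the tiles grow.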

## GEMM introduction

Below, we define the problem size and initialize the matrices.
In `GEMM`, a problem is defined by three numbers: `M`, `N` and `K`.

`D = α * A * B + β * C` (here `A * B` is the matrix product) with:
* `D` shape is `MxN`
* `A` shape is `MxK`
* `B` shape is `KxN`
* `C` shape is `MxN`
* `α` and `β` are scalar constants

> For readability, the code below implicitly sets `α` to 1 and `β` to 0, so `C` does not appear.
> Re-introducing them is straightforward.
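
For reference, the full GEMM with `α`, `β` and `C` can be written directly and checked against `torch.addmm`, which computes `beta * input + alpha * (mat1 @ mat2)`. The shapes and constant values below are arbitrary.

```python
import torch

M, N, K = 4, 3, 5
alpha, beta = 2.0, 0.5

A = torch.rand((M, K))
B = torch.rand((K, N))
C = torch.rand((M, N))

# D = α * A * B + β * C, with D of shape MxN
D = alpha * (A @ B) + beta * C
assert D.shape == (M, N)

# torch.addmm performs the same computation: beta * C + alpha * (A @ B)
assert torch.allclose(D, torch.addmm(C, A, B, beta=beta, alpha=alpha))
```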


```python
import torch

# problem size: A is MxK, B is KxN, so the product A @ B is MxN
M, N0, K0 = 15, 9, 12

A0 = torch.rand((M, K0))
B0 = torch.rand((K0, N0))
```

## Simple matmul with tiling

A simple example showing how to perform a `matmul` through tiling.

A basic introduction to the subject can be found here:

* https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
* https://penny-xu.github.io/blog/tiled-matrix-multiplication

Parallelization can be applied at both the `M` and `N` loop levels, since each output tile can be computed independently of the others.
However, making the best use of global memory accesses requires a smarter approach.
See our dedicated explanation in the tutorials.
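
Because each output tile only depends on one block row of `A` and one block column of `B`, the `M` and `N` loops can be flattened into a list of independent work items, loosely similar to a grid of thread blocks on a GPU. The sketch below uses a hypothetical `compute_output_tile` helper and visits the tiles in a random order to show that the result does not depend on scheduling; on a GPU, these work items would run concurrently instead.

```python
import torch

M, N, K = 15, 9, 12
block_M, block_N, block_K = 5, 3, 4
A = torch.rand((M, K))
B = torch.rand((K, N))


def compute_output_tile(pid_M: int, pid_N: int) -> torch.Tensor:
    # Hypothetical helper: one work item = a full reduction over K for one output tile.
    rows = slice(pid_M * block_M, (pid_M + 1) * block_M)
    cols = slice(pid_N * block_N, (pid_N + 1) * block_N)
    acc = torch.zeros((block_M, block_N))
    for start_K in range(0, K, block_K):
        acc += A[rows, start_K:start_K + block_K] @ B[start_K:start_K + block_K, cols]
    return acc


# Flatten the (M, N) tile grid into a list of program ids.
grid = [(pid_M, pid_N) for pid_M in range(M // block_M) for pid_N in range(N // block_N)]

# Visit the work items in a random order: each one writes a disjoint output tile.
D = torch.zeros((M, N))
for i in torch.randperm(len(grid)).tolist():
    pid_M, pid_N = grid[i]
    rows = slice(pid_M * block_M, (pid_M + 1) * block_M)
    cols = slice(pid_N * block_N, (pid_N + 1) * block_N)
    D[rows, cols] = compute_output_tile(pid_M, pid_N)

assert torch.allclose(D, A @ B)
```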

The values used below are arbitrary and kept small so the matrices can be printed if needed.
A rule of thumb for choosing the tile shape:
* large tiles increase data reuse but decrease thread-level parallelism;
* small tiles increase thread-level parallelism but reduce data reuse.
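
This trade-off can be quantified under a simplified counting model: each `A` and `B` tile is loaded from global memory once per tile product and then reused from fast memory, and caches are ignored. The sketch below sweeps square tile sizes for an arbitrary 1024x1024x1024 problem and reports the number of output tiles (independent work units available for parallelism) and the total global memory loads. The `tile_stats` helper is hypothetical.

```python
# Hypothetical helper: simplified model of parallelism vs. global memory traffic.
def tile_stats(M: int, N: int, K: int, block: int) -> tuple[int, int]:
    # Square block x block tiles on every dimension; shapes must be multiples of block.
    assert M % block == 0 and N % block == 0 and K % block == 0
    n_output_tiles = (M // block) * (N // block)  # independent units of parallel work
    n_tile_products = n_output_tiles * (K // block)
    loads = n_tile_products * 2 * block * block   # one A tile + one B tile per product
    return n_output_tiles, loads


for block in (16, 32, 64, 128):
    n_tiles, loads = tile_stats(1024, 1024, 1024, block)
    print(f"block={block:4d}  output tiles={n_tiles:6d}  GM loads={loads:,}")
```

Under this model, going from 16 to 128 divides the loads by 8 but leaves only 64 output tiles, which can be too few work units to occupy a large GPU.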

```python
# for simplicity, matrix shapes are all multiples of tile shapes;
# otherwise we would need to check matrix bounds and fill out-of-bounds tile values with 0s (masking)
block_M, block_N0, block_K0 = M // 3, N0 // 3, K0 // 3

accumulator0 = torch.zeros((M, N0))
for index_M in range(0, M, block_M):
    start_M = index_M
    end_M = index_M + block_M

    for index_N0 in range(0, N0, block_N0):
        start_N0 = index_N0
        end_N0 = index_N0 + block_N0

        # reduction over K: accumulate tile products into the current output tile
        for index_K0 in range(0, K0, block_K0):
            start_K0 = index_K0
            end_K0 = index_K0 + block_K0

            tile_A0 = A0[start_M:end_M, start_K0:end_K0]
            tile_B0 = B0[start_K0:end_K0, start_N0:end_N0]
            # @ is the matrix multiplication operator in numpy and PyTorch
            accumulator0[start_M:end_M, start_N0:end_N0] += tile_A0 @ tile_B0

assert torch.allclose(accumulator0, A0 @ B0)
```
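
The comment at the top of the tiled cell mentions masking for shapes that are not multiples of the tile shape. A minimal sketch of that idea follows, assuming a hypothetical `load_tile` helper that mimics a masked load: out-of-bounds positions are filled with 0s, so they contribute nothing to the products, and only the in-bounds part of each output tile is stored.

```python
import torch

M, N, K = 17, 10, 13  # deliberately not multiples of the tile shape
block_M, block_N, block_K = 5, 3, 4
A = torch.rand((M, K))
B = torch.rand((K, N))


def load_tile(X: torch.Tensor, row: int, col: int, rows: int, cols: int) -> torch.Tensor:
    # Hypothetical masked load: copy the in-bounds part, leave out-of-bounds entries at 0.
    tile = torch.zeros((rows, cols), dtype=X.dtype)
    part = X[row:row + rows, col:col + cols]  # slicing clamps at the matrix bounds
    tile[:part.shape[0], :part.shape[1]] = part
    return tile


D = torch.zeros((M, N))
for start_M in range(0, M, block_M):
    for start_N in range(0, N, block_N):
        acc = torch.zeros((block_M, block_N))
        for start_K in range(0, K, block_K):
            tile_A = load_tile(A, start_M, start_K, block_M, block_K)
            tile_B = load_tile(B, start_K, start_N, block_K, block_N)
            acc += tile_A @ tile_B  # zero padding adds nothing to the sum
        # Masked store: keep only the rows and columns that exist in D.
        end_M, end_N = min(start_M + block_M, M), min(start_N + block_N, N)
        D[start_M:end_M, start_N:end_N] = acc[:end_M - start_M, :end_N - start_N]

assert torch.allclose(D, A @ B)
```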