# Tiled matmul

In this short tutorial we will see how to perform a `matmul` with tiling.
Tiling is a technique based on matrix partitioning: the matrices are split into blocks, and each block is called a tile.

With tiling:
* the `matmul` computation can be performed in parallel, a domain where GPUs excel;
* global memory (GM) accesses are reduced, GM access being the GPU bottleneck (compared to computation).


In [1]:
import torch

M, N, K = 15, 9, 12

A = torch.rand((M, K))
B = torch.rand((K, N))

# Simple matmul with tiling

Simple example showing how we can perform a `matmul` through tiling.

Basic introduction to the subject can be found here:

* https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
* https://penny-xu.github.io/blog/tiled-matrix-multiplication

Parallelization can be applied at both the `M` and `N` loop levels.
However, making the best use of global memory access requires being a bit smarter.
Check our dedicated explanation in the tutorials.
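
To illustrate why the `M` and `N` loops can be parallelized, here is a minimal sketch showing that each output tile can be computed independently, in any order. The helper `compute_output_tile` and the shuffled processing order are illustrative assumptions, not part of the original notebook; on a GPU, each tile would typically be handled by a separate group of threads.

```python
import random

import torch

torch.manual_seed(0)
M, N, K = 15, 9, 12       # same illustrative shapes as in this tutorial
block_M, block_N = 5, 3   # assumed tile shape, dividing M and N exactly

A = torch.rand((M, K))
B = torch.rand((K, N))


def compute_output_tile(start_M, start_N):
    # hypothetical helper: an output tile only depends on
    # a band of rows of A and a band of columns of B
    return A[start_M:start_M + block_M, :] @ B[:, start_N:start_N + block_N]


# one (start_M, start_N) pair per output tile
tiles = [(m, n) for m in range(0, M, block_M) for n in range(0, N, block_N)]
random.shuffle(tiles)  # the processing order does not matter

output = torch.zeros((M, N))
for start_M, start_N in tiles:
    tile = compute_output_tile(start_M, start_N)
    output[start_M:start_M + block_M, start_N:start_N + block_N] = tile

assert torch.allclose(output, A @ B)
```

Since no tile reads the result of another tile, there is no ordering constraint between them. The `K` loop is different: its iterations all accumulate into the same output tile.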

Values used below are arbitrary and small enough to be printable if needed.
The rule of thumb for choosing a tile shape is:
* a large tile size increases data reuse but decreases thread-level parallelism;
* a small tile size increases thread-level parallelism but reduces data reuse.

In [2]:
# for simplicity, matrix shapes are all multiples of tile shapes
# otherwise we would need to check matrix bounds and fill out-of-bounds tile values with 0s
block_M, block_N, block_K = M // 3, N // 3, K // 2

output = torch.zeros((M, N))
total_load = 0
total_write = 0

for index_M in range(0, M, block_M):
    start_M = index_M
    end_M = index_M + block_M

    for index_N in range(0, N, block_N):
        start_N = index_N
        end_N = index_N + block_N
        accumulator = torch.zeros((block_M, block_N))
        for index_K in range(0, K, block_K):
            start_K = index_K
            end_K = index_K + block_K

            tile_A = A[start_M:end_M, start_K:end_K]
            total_load += tile_A.numel()
            tile_B = B[start_K:end_K, start_N:end_N]
            total_load += tile_B.numel()
            # @ means matmul in numpy and pytorch
            accumulator += tile_A @ tile_B
        output[start_M:end_M, start_N:end_N] = accumulator
        total_write += accumulator.numel()

assert torch.allclose(output, A @ B)
print("total load from GM:", total_load)
print("total write to GM:", total_write)

total load from GM: 864
total write to GM: 135
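
The loop above relies on the matrix shapes being multiples of the tile shapes. As a hedged sketch of the other case mentioned in the code comment, the variant below pads each tile with zeros when it crosses a matrix boundary. The helper `load_tile` and the block sizes `4, 4, 5` are assumptions chosen for illustration; a real GPU kernel would usually apply a mask at load time instead of copying into a zero tensor.

```python
import torch

torch.manual_seed(0)
M, N, K = 15, 9, 12
# deliberately NOT divisors of M, N, K
block_M, block_N, block_K = 4, 4, 5

A = torch.rand((M, K))
B = torch.rand((K, N))


def load_tile(matrix, start_row, start_col, tile_rows, tile_cols):
    # hypothetical helper: returns a full-size tile,
    # positions falling outside of the matrix are set to 0
    tile = torch.zeros((tile_rows, tile_cols), dtype=matrix.dtype)
    end_row = min(start_row + tile_rows, matrix.shape[0])
    end_col = min(start_col + tile_cols, matrix.shape[1])
    valid = matrix[start_row:end_row, start_col:end_col]
    tile[: valid.shape[0], : valid.shape[1]] = valid
    return tile


output = torch.zeros((M, N))
for start_M in range(0, M, block_M):
    for start_N in range(0, N, block_N):
        accumulator = torch.zeros((block_M, block_N))
        for start_K in range(0, K, block_K):
            tile_A = load_tile(A, start_M, start_K, block_M, block_K)
            tile_B = load_tile(B, start_K, start_N, block_K, block_N)
            # zero padding adds nothing to the dot products
            accumulator += tile_A @ tile_B
        # only the in-bounds part of the accumulator is written back
        end_M = min(start_M + block_M, M)
        end_N = min(start_N + block_N, N)
        output[start_M:end_M, start_N:end_N] = accumulator[: end_M - start_M, : end_N - start_N]

assert torch.allclose(output, A @ B)
```

Padding with zeros is safe here because a zero in a tile of `A` or `B` contributes nothing to the accumulated sum.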


In the code above, we have tracked the quantity of data loaded from and written to global memory (GPU DRAM).
The quantity of data written is easy to guess: it is the number of elements in the output matrix, so `M * N`.

In [3]:
M * N

135

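Regarding loads, each innermost iteration loads one tile of `A` and one tile of `B` (`block_M * block_K + block_K * block_N` elements), and this is repeated for every step along the `M`, `N` and `K` axes.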

In [4]:
((block_M * block_K) + (block_K * block_N)) * (K / block_K) * (N / block_N) * (M / block_M)

864.0

Note that making `block_M` and `block_N` smaller increases the number of loads, because each loaded tile of `A` and `B` is then reused for fewer output elements.
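
To make this effect concrete, the sketch below recomputes the load count for a few tile shapes. The function `total_loads` and the list of tile shapes are assumptions introduced for illustration. The count from the expression above simplifies algebraically to `M * N * K * (1 / block_M + 1 / block_N)`, so it does not depend on `block_K`; the code checks both forms against each other.

```python
def total_loads(M, N, K, block_M, block_N, block_K):
    # hypothetical helper: number of elements loaded by the tiled loop,
    # assuming block sizes divide M, N and K exactly
    assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0
    loads_per_step = block_M * block_K + block_K * block_N
    steps = (M // block_M) * (N // block_N) * (K // block_K)
    return loads_per_step * steps


M, N, K = 15, 9, 12

results = []
for block_M, block_N in [(15, 9), (5, 3), (3, 3), (1, 1)]:
    loads = total_loads(M, N, K, block_M, block_N, block_K=6)
    # simplified form, kept in integer arithmetic
    simplified = M * N * K * (block_M + block_N) // (block_M * block_N)
    assert loads == simplified
    # number of output tiles that can be computed independently
    output_tiles = (M // block_M) * (N // block_N)
    results.append((block_M, block_N, output_tiles, loads))
    print(f"block_M={block_M:2d} block_N={block_N} output tiles={output_tiles:3d} loads={loads}")

# the load count does not depend on block_K
assert total_loads(M, N, K, 5, 3, block_K=6) == total_loads(M, N, K, 5, 3, block_K=3)
```

With these shapes the load counts are 288, 864, 1080 and 3240, while the number of independent output tiles grows from 1 to 135. This is the trade-off between data reuse and parallelism described earlier.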