# Computing coordinates on Triton GEMM example

## Introduction

The original tutorial is available here:
https://triton-lang.org/master/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py

The most important part of the tutorial is how the work is distributed across the streaming multiprocessors (`SM`s) of the GPU.  
`Triton` has a concept of `program`, which is basically the smallest unit of work.  

The order in which these `program`s are executed has a direct impact on the performance of the computation.  

In the matmul example, the `triton` authors introduce the notion of groups of `program`s.  
It is a way to order the `program`s so that those running (partially) in parallel share the same memory accesses.  

> Memory access is the most expensive operation in the kernel execution, often much more than the computation itself.  

It's only "partially" in parallel because `GPU`s are not always big enough to process large matrices in a single step.  
That is why we need to split the matrix into smaller pieces and process only some pieces in parallel.

The part of the tutorial we are interested in:

```python
# program ID
pid = tl.program_id(axis=0)
# number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# id of the group this program is in
group_id = pid // num_pid_in_group
# row id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *within groups*, programs are ordered in a column-major order
# row id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
```

## Initialization

We will redo the computation in Python, with a bit of `torch`, to see what the `triton` code above means.

> Reminder: GEMM problems are usually presented as `MNK` problems, where matrix `A` is `MxK` (input), matrix `B` is `KxN` 
> (input) and matrix `C` is `MxN` (output).

In [1]:
import torch


def cdiv(x, y):
    """
    Ceiling division returns the closest integer greater than or equal to the quotient.
    """
    return (x + y - 1) // y


M = 1024
N = 768
K = 128

print(f"C [{M}x{N}] = A [{M}x{K}] * B [{K}x{N}]")

a = torch.rand(M, K)
b = torch.rand(K, N)
c = torch.rand(M, N)

# programs work at the tile level, below are their dimensions for each axis
BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = 32

# group is a special concept to speed up matmul in this tutorial
GROUP_SIZE_M = 2

# number of programs to run along each axis of C (each program will iterate over K)
num_pid_m = cdiv(M, BLOCK_SIZE_M)  # number of `program`s in M dimension, rounded up
num_pid_n = cdiv(N, BLOCK_SIZE_N)  # number of `program`s in N dimension, rounded up

print("num_pid_m:", num_pid_m)
print("num_pid_n:", num_pid_n)
print("num_pid_n * GROUP_SIZE_M =", num_pid_n * GROUP_SIZE_M)

C [1024x768] = A [1024x128] * B [128x768]
num_pid_m: 8
num_pid_n: 12
num_pid_n * GROUP_SIZE_M = 24

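The sizes above divide evenly, so the rounding up never kicks in. As a small sketch with a hypothetical `M = 1000` (a value not used elsewhere in this walkthrough), ceiling division adds one extra, partially filled tile, and the integer formula agrees with `math.ceil`:

```python
import math


def cdiv(x, y):
    """Ceiling division on positive integers."""
    return (x + y - 1) // y


BLOCK_SIZE_M = 128

# 1024 is a multiple of 128: exactly 8 tiles, nothing to round up
assert cdiv(1024, BLOCK_SIZE_M) == 8

# hypothetical size that is NOT a multiple of 128: 7 full tiles + 1 partial tile
M_odd = 1000
num_tiles = cdiv(M_odd, BLOCK_SIZE_M)
rows_in_last_tile = M_odd - (num_tiles - 1) * BLOCK_SIZE_M

print("tiles:", num_tiles, "rows in last tile:", rows_in_last_tile)
assert num_tiles == math.ceil(M_odd / BLOCK_SIZE_M)
```

A partial tile is the reason a kernel has to guard its loads and stores against out-of-bounds rows and columns.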

The main trick to limit global memory (`GM`, i.e. the off-chip DRAM of the GPU) accesses is to run in parallel `program`s that need the same part of an input matrix.

Each `program` has a unique id, which is its index in the list of `program`s.  
Since each `program` computes one block of `C` and iterates over the `K` axis by itself, the formula below (from the tutorial) gives the total number of `program`s to launch:

```python
grid = lambda META: (
    triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
```

Below, we redo the same computation with the variables defined above.

In [2]:
nb_programs = num_pid_m * num_pid_n  # number of programs to launch
print("nb programs to launch:", nb_programs)

nb programs to launch: 96

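For comparison, the most straightforward way to turn a 1D `pid` into 2D tile coordinates is a plain row-major ordering; the grouped ordering studied in the rest of this walkthrough is an alternative to it. The sketch below is plain Python, and the helper name `row_major_tile` is ours, not part of `triton`.

```python
num_pid_m, num_pid_n = 8, 12  # values computed earlier in this walkthrough


def row_major_tile(pid, num_pid_n):
    """Map a 1D program id to (pid_m, pid_n), filling each row of tiles first."""
    return pid // num_pid_n, pid % num_pid_n


nb_programs = num_pid_m * num_pid_n
tiles = [row_major_tile(pid, num_pid_n) for pid in range(nb_programs)]

# every tile of C is visited exactly once
assert len(set(tiles)) == nb_programs

print(tiles[:3], "...", tiles[-1])
print("pid=60 ->", row_major_tile(60, num_pid_n))
```

With this ordering, consecutive `pid`s stay on the same block row of `A` and walk through every block column of `B`.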

In [3]:
# some example program ID
pid = 60
assert (
    pid < nb_programs
), f"we will launch a {num_pid_m}x{num_pid_n}={nb_programs} grid of programs, pid={pid} is too big"

We need at least `num_pid_m` `program`s to cover the `M` axis (number of rows of matrix `A`) and `num_pid_n` `program`s
to cover the `N` axis (number of columns of matrix `B`).  
Each of those `program`s will have to iterate over the `K` axis of `A` and `B` matrices.  

The goal of our strategy is to reuse data as much as possible.  
For that, we will run in parallel several `program`s that need the same data from one of the input matrices.  

Each `program` consumes `A` and `B` matrices.  
We can reuse data from either `A` or `B`; we arbitrarily choose to reuse data from `B`.  
It means that we run several `program`s on different block rows of `A` (axis `M`) that consume the same block column of `B` (axis `N`).  

> Reminder: each `program` is responsible for iterating over the `K` axis.

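To make the benefit of this reuse concrete, the sketch below counts how many distinct block rows of `A` and block columns of `B` a batch of concurrently running `program`s has to load, depending on the shape of the region of `C` they cover. The batch size of 12 `program`s is an arbitrary assumption for illustration (the real number depends on the GPU), and `blocks_needed` is a hypothetical helper.

```python
def blocks_needed(tiles):
    """Distinct block rows of A plus distinct block columns of B touched by the tiles."""
    rows = {m for m, _ in tiles}
    cols = {n for _, n in tiles}
    return len(rows) + len(cols)


# 12 programs covering a 1 x 12 strip of C tiles (what a row-major ordering gives)
strip = [(0, n) for n in range(12)]

# 12 programs covering a 2 x 6 region of C tiles (what a group of 2 block rows gives)
group = [(m, n) for n in range(6) for m in range(2)]

print("1 x 12 strip :", blocks_needed(strip))
print("2 x 6 region :", blocks_needed(group))
```
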
Expressed as pseudo-code, it looks like a loop nest where the groups and the block columns (axis `N` of the `B` matrix) are mostly iterated serially,  
while the block rows of a group (axis `M` of the `A` matrix) are processed in parallel.

```python
for first_pid_m in range(0, num_pid_m, GROUP_SIZE_M):  # one group of block rows after the other
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)  # the last group may be smaller
    for pos_n in range(num_pid_n):  # (mostly) serialized iteration
        for pos_m in range(first_pid_m, first_pid_m + group_size_m):  # programs in parallel
            # each program is associated with a position on the M and N axis and will iterate itself over the K axis
            do_work(pos_m, pos_n)
```
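
A runnable sketch of this ordering, with an explicit outer loop over groups of `GROUP_SIZE_M` block rows, helps to check it. It is plain Python and only models the order of the `pid`s; on a real GPU the scheduler decides which `program`s actually run at the same time. The position in `launch_order` plays the role of `pid`.

```python
num_pid_m, num_pid_n, GROUP_SIZE_M = 8, 12, 2  # values used in this walkthrough

launch_order = []
for first_pid_m in range(0, num_pid_m, GROUP_SIZE_M):  # one group of block rows at a time
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)  # last group may be smaller
    for pos_n in range(num_pid_n):  # sweep the block columns of B
        for pos_m in range(first_pid_m, first_pid_m + group_size_m):  # rows sharing that column
            launch_order.append((pos_m, pos_n))

# every (pos_m, pos_n) tile of C appears exactly once
assert len(launch_order) == len(set(launch_order)) == num_pid_m * num_pid_n

print("first tiles:", launch_order[:4])
print("tile at position 60:", launch_order[60])
```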

Within a group, `GROUP_SIZE_M` `program`s (positioned on `GROUP_SIZE_M` different block rows on the `M` axis of `A`) 
consume the same block column of the `B` matrix, and the group sweeps every block column of the `N` axis before we switch to the next group.  
A group therefore contains `GROUP_SIZE_M * num_pid_n` `program`s, which together compute a band of `GROUP_SIZE_M` block rows of the `C` matrix.

In [4]:
num_pid_in_group = GROUP_SIZE_M * num_pid_n
assert num_pid_n * GROUP_SIZE_M <= M
print(num_pid_in_group)

24


Since each group contains `num_pid_in_group` `program`s, an integer division gives the group id of the current `program`:

In [5]:
group_id = pid // num_pid_in_group
print(group_id)

2

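As a quick check of that integer division, this sketch lists which range of `pid`s falls into each group, using the sizes of this walkthrough (96 `program`s, 24 per group):

```python
num_pid_n, GROUP_SIZE_M, nb_programs = 12, 2, 96  # values used in this walkthrough
num_pid_in_group = GROUP_SIZE_M * num_pid_n

# collect the pids of each group
groups = {}
for pid in range(nb_programs):
    groups.setdefault(pid // num_pid_in_group, []).append(pid)

for group_id, pids in groups.items():
    print(f"group {group_id}: pids {pids[0]}..{pids[-1]}")
```

`pid = 60` sits between 48 and 71, hence in group 2.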

Now we will:
* compute the row id (on the `M` axis) of the first program in our group;
* compute the real size of the group, to catch the case of the last group, which is smaller than the others when `num_pid_m` is not divisible by `GROUP_SIZE_M`.


In [6]:
# row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M

print("first_pid_m:", first_pid_m)
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
print("group_size_m:", group_size_m)

assert group_size_m > 0

first_pid_m: 4
group_size_m: 2

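With the sizes of this walkthrough, `num_pid_m` is a multiple of `GROUP_SIZE_M`, so every group has the same size. The sketch below uses a hypothetical `GROUP_SIZE_M = 3` to show the case the `min` is there for:

```python
num_pid_m = 8
GROUP_SIZE_M = 3  # hypothetical value; this walkthrough uses 2

# row id of the first program of each group
first_rows = list(range(0, num_pid_m, GROUP_SIZE_M))
# real size of each group: the last one only gets the remaining rows
group_sizes = [min(num_pid_m - first_pid_m, GROUP_SIZE_M) for first_pid_m in first_rows]

print("first_pid_m of each group:", first_rows)
print("group_size_m of each group:", group_sizes)

# all block rows are covered, no more, no less
assert sum(group_sizes) == num_pid_m
```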

We know which group we are part of, so we just need to find our spot in this group.  
> As noted in the original comments of the tutorial, *"within groups, programs are ordered in a column-major order"*.

In [7]:
pid_m = first_pid_m + (pid % group_size_m)
print("pid_m:", pid_m)

pid_n = (pid % num_pid_in_group) // group_size_m
print("pid_n:", pid_n)

pid_m: 4
pid_n: 6

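To see the whole picture at once, the sketch below wraps the steps above in a function (the name `tile_coordinates` is ours) and writes every `pid` at its position on the grid of `C` tiles. It is plain Python and only reproduces the index arithmetic of the tutorial snippet.

```python
def tile_coordinates(pid, num_pid_m, num_pid_n, group_size):
    """Grouped ordering: map a 1D program id to (pid_m, pid_n)."""
    num_pid_in_group = group_size * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size
    group_size_m = min(num_pid_m - first_pid_m, group_size)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


num_pid_m, num_pid_n, GROUP_SIZE_M = 8, 12, 2  # values used in this walkthrough

grid = [[None] * num_pid_n for _ in range(num_pid_m)]
for pid in range(num_pid_m * num_pid_n):
    pid_m, pid_n = tile_coordinates(pid, num_pid_m, num_pid_n, GROUP_SIZE_M)
    assert grid[pid_m][pid_n] is None  # no tile is assigned twice
    grid[pid_m][pid_n] = pid

# one line per block row of C, one column per block column of C
for row in grid:
    print(" ".join(f"{pid:2d}" for pid in row))
```

Reading the printed grid, consecutive `pid`s go down a column inside a band of `GROUP_SIZE_M` rows before moving to the next column, which is the column-major order within groups mentioned in the tutorial comments.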

To conclude this explanation, we will see how `pid_m` and `pid_n` are used to generate the memory offsets to read / write.

> The code below doesn't make sense in the `triton` context and is just here for the sake of completeness.


In [8]:
# torch semantic to retrieve data array pointer
a_ptr = a.storage().data_ptr()
b_ptr = b.storage().data_ptr()
c_ptr = c.storage().data_ptr()

# stride is the memory offset, in number of elements, between two consecutive rows / columns
stride_am, stride_ak = a.stride()
stride_bk, stride_bn = b.stride()
stride_cm, stride_cn = c.stride()

# below we perform the conversion from starting pointer to a matrix of pointers
offs_am = pid_m * BLOCK_SIZE_M + torch.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + torch.arange(0, BLOCK_SIZE_N)
offs_k = torch.arange(0, BLOCK_SIZE_K)

# broadcasting is leveraged to compute the offsets
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

print("a_ptrs\n", a_ptrs)
print("b_ptrs\n", b_ptrs)

a_ptrs
 tensor([[75261632, 75261633, 75261634,  ..., 75261661, 75261662, 75261663],
        [75261760, 75261761, 75261762,  ..., 75261789, 75261790, 75261791],
        [75261888, 75261889, 75261890,  ..., 75261917, 75261918, 75261919],
        ...,
        [75277632, 75277633, 75277634,  ..., 75277661, 75277662, 75277663],
        [75277760, 75277761, 75277762,  ..., 75277789, 75277790, 75277791],
        [75277888, 75277889, 75277890,  ..., 75277917, 75277918, 75277919]])
b_ptrs
 tensor([[75722624, 75722625, 75722626,  ..., 75722685, 75722686, 75722687],
        [75723392, 75723393, 75723394,  ..., 75723453, 75723454, 75723455],
        [75724160, 75724161, 75724162,  ..., 75724221, 75724222, 75724223],
        ...,
        [75744896, 75744897, 75744898,  ..., 75744957, 75744958, 75744959],
        [75745664, 75745665, 75745666,  ..., 75745725, 75745726, 75745727],
        [75746432, 75746433, 75746434,  ..., 75746493, 75746494, 75746495]])

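
As a complement, here is a minimal sketch of the same offset arithmetic on a tiny matrix stored as a flat Python list, which stands in for GPU memory. The sizes are hypothetical and chosen to keep the output readable; the variable names mirror the cell above. Each value of the list is equal to its own index, so the loaded tile can be compared directly with the offsets.

```python
M, K = 8, 6
a_flat = list(range(M * K))  # row-major storage: element (m, k) lives at index m * K + k
stride_am, stride_ak = K, 1  # strides in number of elements, as returned by `a.stride()`

BLOCK_SIZE_M, BLOCK_SIZE_K = 4, 2
pid_m = 1  # this program handles the second block row

# rows of A handled by this program
offs_am = [pid_m * BLOCK_SIZE_M + i for i in range(BLOCK_SIZE_M)]

# the program walks along the K axis, one block of BLOCK_SIZE_K columns at a time
for k_start in range(0, K, BLOCK_SIZE_K):
    offs_k = [k_start + j for j in range(BLOCK_SIZE_K)]
    # 2D grid of flat offsets, the list equivalent of the broadcasted sum above
    a_offsets = [[m * stride_am + k * stride_ak for k in offs_k] for m in offs_am]
    # "load" the tile by reading the flat storage at those offsets
    a_tile = [[a_flat[o] for o in row] for row in a_offsets]
    print(f"k_start={k_start}: offsets of the first row = {a_offsets[0]}")
```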