# Computing program coordinates in the Triton GEMM example

The tutorial is available here: https://triton-lang.org/master/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py

The most important part is the distribution of the work across the GPU.  
Triton has the concept of a `program`, which is essentially the smallest unit of work.  

The order in which these programs are executed has a direct impact on the performance of the computation.  
In the matmul example, the `triton` authors introduce the notion of groups of programs.  
Grouping orders the programs so that programs running at the same time share the same memory accesses, which improves the L2 cache hit rate.  
Memory access is typically the most expensive part of the kernel execution, often much more than the computation itself.  

The part of the tutorial we are interested in:

```python
# program ID
pid = tl.program_id(axis=0)
# number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# id of the group this program is in
group_id = pid // num_pid_in_group
# row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *within groups*, programs are ordered in a column-major order
# row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
```
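Outside of a GPU, the snippet can be checked by wrapping it into a plain Python function. This is a sketch under two assumptions: `tl.program_id(axis=0)` is replaced by a `pid` argument and `tl.cdiv` by integer ceiling division. The name `grouped_tile_coords` is hypothetical and not part of Triton.

```python
def grouped_tile_coords(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Hypothetical helper: map a 1-D program id to its (pid_m, pid_n) block of C,
    following the grouped ordering of the Triton matmul tutorial."""
    num_pid_m = -(-M // BLOCK_SIZE_M)  # ceiling division, like tl.cdiv
    num_pid_n = -(-N // BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


# the first programs alternate between two block rows while pid_n advances slowly:
# (0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)
for pid in range(6):
    print(pid, grouped_tile_coords(pid, M=1024, N=768, BLOCK_SIZE_M=128, BLOCK_SIZE_N=64, GROUP_SIZE_M=2))
```

With the sizes used later in this notebook, `grouped_tile_coords(60, 1024, 768, 128, 64, 2)` returns `(4, 6)`, the same `pid_m` and `pid_n` as the step-by-step computation below.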

We will try to redo the computation in `numpy` to see what it means.  


Reminder of Gaetan's explanations:

```python
# The goal is to group the programs so that one of the traversals (M or N) stays within a limited region
# So we choose small groups along M, of size GROUP_SIZE_M
# For each of these groups, we go through all the blocks along the N dimension, but along M we stay inside the group

# Number of programs that will run along each axis
num_pid_m = cdiv(M, BLOCK_SIZE_M)
num_pid_n = cdiv(N, BLOCK_SIZE_N)

# GROUP_SIZE_M programs along M run for each of the num_pid_n programs along the N axis
# So we go through num_pid_in_group instances before moving on to the next group
num_pid_in_group = GROUP_SIZE_M * num_pid_n

# This puts the first num_pid_in_group programs in group 0, the next ones in group 1, etc.
group_id = pid // num_pid_in_group

# At this point we know our group; now we need to find our position on M and N
# Row-id (along M) of the first program in the group; our own row-id is greater than or equal to this number
first_pid_m = group_id * GROUP_SIZE_M

# This handles the case where the number of programs along M is not divisible by GROUP_SIZE_M
# For example, 15 programs and groups of 8:
# for the second group, group_size_m = 15 - 8 = 7, and there are indeed only 7 programs in the second group
# for the first group, 15 - 0 = 15 and GROUP_SIZE_M is 8, so there are indeed 8 programs in the first group
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *within groups*, programs are ordered in a column-major order
# row-id of the program in the *launch grid*
#
# So basically, first_pid_m is the first row-id of our group, and our own row-id is greater than or equal to it
# Now we need to find our position on a grid of size GROUP_SIZE_M x num_pid_n
# We want M to vary first, so we move along the GROUP_SIZE_M dimension within our group
pid_m = first_pid_m + (pid % group_size_m)
# While we move along the GROUP_SIZE_M axis, the position on N does not change; (pid % num_pid_in_group) is the index relative to the group
# and with // group_size_m, the first group_size_m programs are at 0 on N, the next group_size_m at 1, etc.
pid_n = (pid % num_pid_in_group) // group_size_m
```
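The numeric example in the comments (15 programs along `M`, groups of 8) can be checked with a few lines of Python. The sketch redefines a `cdiv` ceiling division helper so that it runs on its own.

```python
def cdiv(x, y):
    # ceiling division: smallest integer >= x / y
    return (x + y - 1) // y


num_pid_m = 15     # 15 programs along M, as in the example above
GROUP_SIZE_M = 8   # groups of 8 block rows

sizes = []
for group_id in range(cdiv(num_pid_m, GROUP_SIZE_M)):
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    sizes.append(group_size_m)
    print(f"group {group_id}: first_pid_m={first_pid_m}, group_size_m={group_size_m}")
# group 0: first_pid_m=0, group_size_m=8
# group 1: first_pid_m=8, group_size_m=7
```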

```python
import numpy as np

def cdiv(x, y):
    """
    Ceiling division: returns the smallest integer greater than or equal to x / y.
    """
    return (x + y - 1) // y


M = 1024
N = 768
K = 128

print(f"C [{M}x{N}] = A [{M}x{K}] * B [{K}x{N}]")

a = np.random.rand(M, K)
b = np.random.rand(K, N)
c = np.zeros((M, N))  # output matrix

# tile sizes
BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 2

num_pid_m = cdiv(M, BLOCK_SIZE_M)  # number of programs in M dimension
num_pid_n = cdiv(N, BLOCK_SIZE_N)  # number of programs in N dimension

print("num_pid_m:", num_pid_m)
print("num_pid_n:", num_pid_n)
print("num_pid_n * GROUP_SIZE_M =", num_pid_n * GROUP_SIZE_M)
```

```
C [1024x768] = A [1024x128] * B [128x768]
num_pid_m: 8
num_pid_n: 12
num_pid_n * GROUP_SIZE_M = 24
```


`Program` is the smallest unit of work in `Triton`.  
Programs may be executed in parallel, and the order in which they are executed matters for performance.  

Indeed, a key element for performance is to limit memory accesses, as they are among the most expensive operations
in the kernel execution (often more expensive than the computation itself).  
To that end, we try to run programs that share memory accesses in parallel.

Each program has a unique id, which is its index in the list of programs.  
Each program computes one block of `C` and iterates over the `K` axis itself, so the formula below gives the number of programs to launch (one per block of `C`):

```python
grid = lambda META: (
    triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
```
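In Triton, `grid` is evaluated with the kernel's meta-parameters, which is why it is written as a function of `META`. The sketch below emulates that call with a plain dictionary and a local `cdiv` instead of `triton.cdiv` (an assumption so that it runs without Triton installed). Note that `K` does not appear: it only changes how long each program loops, not how many programs are launched.

```python
def cdiv(x, y):
    # ceiling division, standing in for triton.cdiv
    return (x + y - 1) // y


M, N = 1024, 768

# same shape as the tutorial's grid: a 1-D launch grid with one program per block of C
grid = lambda META: (
    cdiv(M, META['BLOCK_SIZE_M']) * cdiv(N, META['BLOCK_SIZE_N']),
)

print(grid({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}))  # (96,)
print(grid({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}))   # (192,): smaller blocks, more programs
```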

Below, we redo the computation with the variables defined above.

```python
nb_programs = num_pid_m * num_pid_n  # number of programs to launch
print("nb programs to launch:", nb_programs)
```

```
nb programs to launch: 96
```


```python
# some example program ID
pid = 60
assert pid < nb_programs, f"we will launch a {num_pid_m}x{num_pid_n}={nb_programs} grid of programs, pid={pid} is too big"
```

We need `num_pid_m` blocks to cover the `M` axis (the rows of matrix `A`) and `num_pid_n` blocks to cover the `N` axis (the columns of matrix `B`),
hence one program per block of `C`, `num_pid_m * num_pid_n` programs in total.  
Each of those programs has to iterate over the `K` axis of the `A` and `B` matrices.  

The goal of our strategy is to reuse data as much as possible.  
For that, we run in parallel several programs that need the same data from one of the input matrices.  

Each `program` reads blocks from both the `A` and `B` matrices.  
We could reuse data from either `A` or `B`; we arbitrarily choose to reuse data from `B`.  
This means that we run in parallel several programs working on different block rows of `A` (axis `M`) that consume the same block column of `B` (axis `N`).

> reminder: each `program` is responsible for iterating over the `K` axis

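If we express that logic in pseudo code, it looks like nested loops where the loop over the groups and the loop over the `N` axis of `B` are mostly iterated serially,  
while the innermost loop over the `M` axis of `A` (restricted to the current group) is parallelized:

```python
# uses cdiv, num_pid_m, num_pid_n and GROUP_SIZE_M defined above
launch_order = []  # block of C handled by each program, following the grouped ordering
for group_id in range(cdiv(num_pid_m, GROUP_SIZE_M)):  # groups one after the other
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    for pid_n in range(num_pid_n):  # mostly serialized iteration along N
        for pid_m in range(first_pid_m, first_pid_m + group_size_m):  # group_size_m programs in parallel
            # each program owns one (pid_m, pid_n) block of C and iterates over the K axis itself
            launch_order.append((pid_m, pid_n))
```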


At each step along `N`, a group runs `GROUP_SIZE_M` programs positioned on `GROUP_SIZE_M` different block rows of `A` (axis `M`), which all consume the same block column of `B`.  
The group goes through all the block columns of `B` along the `N` axis before the next group starts.  
A group therefore contains `GROUP_SIZE_M * num_pid_n` programs, which together compute a horizontal band of `C` made of `GROUP_SIZE_M` block rows.
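One rough way to see the benefit is to count, for a window of programs assumed to run at the same time, how many distinct block rows of `A` and block columns of `B` they touch. The sketch compares a plain row-major mapping (`pid_m = pid // num_pid_n`, `pid_n = pid % num_pid_n`, used here only as a baseline) with the grouped mapping. The window of 16 programs is an arbitrary assumption: real concurrency depends on the GPU and the kernel configuration, and each counted block is really a strip of `K / BLOCK_SIZE_K` tiles, so only the relative numbers matter.

```python
def row_major(pid, num_pid_m, num_pid_n, group_size):
    # baseline: walk C block by block, row after row (group_size is ignored)
    return pid // num_pid_n, pid % num_pid_n


def grouped(pid, num_pid_m, num_pid_n, group_size):
    # grouped ordering from the Triton tutorial
    num_pid_in_group = group_size * num_pid_n
    first_pid_m = (pid // num_pid_in_group) * group_size
    group_size_m = min(num_pid_m - first_pid_m, group_size)
    return first_pid_m + pid % group_size_m, (pid % num_pid_in_group) // group_size_m


def blocks_touched(mapping, window, num_pid_m=8, num_pid_n=12, group_size=2):
    """Distinct A block rows + distinct B block columns used by programs 0..window-1."""
    coords = [mapping(pid, num_pid_m, num_pid_n, group_size) for pid in range(window)]
    rows = {m for m, _ in coords}
    cols = {n for _, n in coords}
    return len(rows) + len(cols)


print("row-major        :", blocks_touched(row_major, 16))              # 2 + 12 = 14
print("grouped (size 2) :", blocks_touched(grouped, 16, group_size=2))  # 2 + 8 = 10
print("grouped (size 4) :", blocks_touched(grouped, 16, group_size=4))  # 4 + 4 = 8
```

A larger group makes the set of programs running together more "square", which lowers the number of distinct blocks they need, as long as the group still fits the `M` axis.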

```python
num_pid_in_group = GROUP_SIZE_M * num_pid_n
assert nb_programs % num_pid_in_group == 0  # here every group is complete (num_pid_m is a multiple of GROUP_SIZE_M)
print(num_pid_in_group)
```

```
24
```


Since each group contains `num_pid_in_group` consecutive program ids, an integer division gives the group id of the current program:

```python
group_id = pid // num_pid_in_group
print(group_id)
```

```
2
```
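The same integer division applied to every `pid` shows how the 96 programs split into groups of consecutive ids. This short sketch hard-codes the sizes of this notebook so that it runs on its own.

```python
import numpy as np

num_pid_m, num_pid_n, GROUP_SIZE_M = 8, 12, 2
num_pid_in_group = GROUP_SIZE_M * num_pid_n  # 24
pids = np.arange(num_pid_m * num_pid_n)      # the 96 program ids
group_ids = pids // num_pid_in_group

# each group holds num_pid_in_group consecutive pids
print(np.bincount(group_ids))  # [24 24 24 24]
print(group_ids[60])           # 2, as for pid = 60 above
```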


Now we will:
* compute `first_pid_m`, the row-id (along `M`) of the first program in our group;
* compute the actual size of the group, to catch the case of the last group, which is smaller than the others when `num_pid_m` is not divisible by `GROUP_SIZE_M`.


```python
# row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M

print("first_pid_m:", first_pid_m)
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
print("group_size_m:", group_size_m)

assert group_size_m > 0
```

```
first_pid_m: 4
group_size_m: 2
```


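We know which group we are part of, so we just need to find our spot in this group.  
> As noted in the original comments of the tutorial, *"within groups, programs are ordered in a column-major order"*.

Consecutive `pid`s first move down along `M` (`pid_m` cycles through the `group_size_m` block rows of the group) and only then move to the next block column along `N`.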
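To visualise this column-major order, the sketch below writes each `pid` at the position of the block of `C` it computes, for a deliberately small grid (4 x 3 blocks, `GROUP_SIZE_M = 2`; these sizes are arbitrary and chosen only to keep the printout readable).

```python
import numpy as np

num_pid_m, num_pid_n, GROUP_SIZE_M = 4, 3, 2  # small, arbitrary example sizes
num_pid_in_group = GROUP_SIZE_M * num_pid_n

layout = np.empty((num_pid_m, num_pid_n), dtype=int)
for pid in range(num_pid_m * num_pid_n):
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    layout[pid_m, pid_n] = pid  # row = position along M, column = position along N

print(layout)
# [[ 0  2  4]
#  [ 1  3  5]
#  [ 6  8 10]
#  [ 7  9 11]]
```

Inside each band of two block rows, the ids increase down a column before moving to the next column, and the second band only starts once the first one is complete.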

```python
pid_m = first_pid_m + (pid % group_size_m)
print("pid_m:", pid_m)

pid_n = (pid % num_pid_in_group) // group_size_m
print("pid_n:", pid_n)
```

```
pid_m: 4
pid_n: 6
```
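Repeating the computation for every `pid` is a quick sanity check that the grouped ordering is only a permutation: each block of `C` should be assigned to exactly one program. In this sketch the helper names `tile_assignment` and `is_permutation_of_tiles` are hypothetical; the second call uses a case where the last group is smaller.

```python
def tile_assignment(num_pid_m, num_pid_n, GROUP_SIZE_M):
    """List the (pid_m, pid_n) block computed by each pid, in pid order."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    tiles = []
    for pid in range(num_pid_m * num_pid_n):
        first_pid_m = (pid // num_pid_in_group) * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        tiles.append((pid_m, pid_n))
    return tiles


def is_permutation_of_tiles(num_pid_m, num_pid_n, GROUP_SIZE_M):
    tiles = tile_assignment(num_pid_m, num_pid_n, GROUP_SIZE_M)
    all_tiles = {(m, n) for m in range(num_pid_m) for n in range(num_pid_n)}
    # no block computed twice and no block forgotten
    return len(tiles) == len(set(tiles)) and set(tiles) == all_tiles


print(tile_assignment(8, 12, 2)[60])      # (4, 6), as computed above
print(is_permutation_of_tiles(8, 12, 2))  # True
print(is_permutation_of_tiles(15, 4, 8))  # True, even with a smaller last group
```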


To finish, a few broadcast operations convert `pid_m` and `pid_n` into the arrays of addresses (pointers) the program will read:

```python
# raw addresses of the first element of each matrix (they play the role of Triton's a_ptr, b_ptr, c_ptr)
a_ptr, _ = a.__array_interface__['data']
b_ptr, _ = b.__array_interface__['data']
c_ptr, _ = c.__array_interface__['data']

# numpy strides are expressed in bytes (Triton strides are in elements), which is
# consistent here because the "pointers" above are raw byte addresses
stride_am, stride_ak = a.strides
stride_bk, stride_bn = b.strides
stride_cm, stride_cn = c.strides

offs_am = pid_m * BLOCK_SIZE_M + np.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + np.arange(0, BLOCK_SIZE_N)
offs_k = np.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
```
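The pointer arithmetic can be checked against plain slicing. The sketch below is a numpy emulation, not Triton code. It assumes C-contiguous `float64` matrices, replaces raw byte addresses with element indices into flattened views (so strides are divided by `itemsize`, matching Triton's element strides), and assumes `K` is a multiple of `BLOCK_SIZE_K` so that no masking is needed. It then advances the pointers along `K`, as each program does in the tutorial's inner loop, and compares the accumulated block with the matching block of `a @ b`.

```python
import numpy as np

M, N, K = 1024, 768, 128
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 64, 32
pid_m, pid_n = 4, 6  # values obtained above for pid = 60

rng = np.random.default_rng(0)
a = rng.random((M, K))
b = rng.random((K, N))

# flat views: a "pointer" becomes an element index into these 1-D arrays
a_flat, b_flat = a.ravel(), b.ravel()
# numpy strides are in bytes; divide by itemsize to get element strides like Triton's
stride_am, stride_ak = (s // a.itemsize for s in a.strides)
stride_bk, stride_bn = (s // b.itemsize for s in b.strides)

offs_am = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)
offs_k = np.arange(BLOCK_SIZE_K)
a_idx = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
b_idx = offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

acc = np.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N))
for _ in range(K // BLOCK_SIZE_K):
    acc += a_flat[a_idx] @ b_flat[b_idx]  # "load" both tiles and multiply them
    a_idx += BLOCK_SIZE_K * stride_ak     # move the A tile to the right along K
    b_idx += BLOCK_SIZE_K * stride_bk     # move the B tile down along K

expected = (a @ b)[pid_m * BLOCK_SIZE_M:(pid_m + 1) * BLOCK_SIZE_M,
                   pid_n * BLOCK_SIZE_N:(pid_n + 1) * BLOCK_SIZE_N]
print(np.allclose(acc, expected))  # True
```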