# Better matmul

Source: [Triton matmul](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py)

We can implement a better version of matmul that operates on blocks instead of individual rows and columns.  Each program then computes a whole tile of the output and reuses the data it loads, so every thread block does more useful work.  It's not the most efficient way to do it (for that we would also need to group the row blocks, as the Triton tutorial does), but it's getting closer.

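This uses the pseudocode below.  The idea is to load a `BLOCK_SIZE_M x BLOCK_SIZE_K` block of matrix `A` and a `BLOCK_SIZE_K x BLOCK_SIZE_N` block of matrix `B`, and multiply them.  You then step along the `K` dimension in increments of `BLOCK_SIZE_K`, adding each partial product to an accumulator.  Once the whole inner dimension is covered, the accumulator holds the final block of matmul results, which you store back to `C`.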

```
# Do in parallel
for m in range(0, M, BLOCK_SIZE_M):
  # Do in parallel
  for n in range(0, N, BLOCK_SIZE_N):
    acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
    for k in range(0, K, BLOCK_SIZE_K):
      a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
      b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
      acc += dot(a, b)
    C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
```
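
As a rough CPU reference for the pseudocode above, here is a minimal NumPy sketch. The function name `blocked_matmul` and the test shapes are assumptions for illustration; the two outer loops run sequentially here, whereas on a GPU each `(m, n)` pair would be a separate program.

```python
import numpy as np

def blocked_matmul(A, B, block_m=16, block_n=16, block_k=32):
    """Blocked matmul on the CPU; NumPy slicing clips blocks at the matrix edges."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.float32)
    for m in range(0, M, block_m):          # parallel over row blocks on a GPU
        for n in range(0, N, block_n):      # parallel over column blocks on a GPU
            # Accumulator sized to the (possibly clipped) output block
            acc = np.zeros((min(block_m, M - m), min(block_n, N - n)), dtype=np.float32)
            for k in range(0, K, block_k):  # sequential reduction over K
                a = A[m:m + block_m, k:k + block_k]
                b = B[k:k + block_k, n:n + block_n]
                acc += a @ b
            C[m:m + block_m, n:n + block_n] = acc
    return C

# Shapes chosen so that no dimension is a multiple of its block size
rng = np.random.default_rng(0)
A = rng.random((40, 70), dtype=np.float32)
B = rng.random((70, 24), dtype=np.float32)
C = blocked_matmul(A, B)
```

Comparing `C` against `A @ B` with `np.allclose` is a quick way to check the blocking logic before moving to the GPU.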

We can define our launch function.  This will use a 2D launch grid of shape `(ceil(X.shape[0] / BLOCK_SIZE_X), ceil(Y.shape[1] / BLOCK_SIZE_Y))`.  Each program in the grid takes a block of rows from `X` and multiplies it by the corresponding block of columns from `Y` to generate one block of the output.

Triton's `tl.dot` needs each block dimension to be at least `16`, so that is the minimum block size we can use.
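
To make the grid shape concrete, here is a small pure-Python sketch of the ceiling division involved. The helper names `cdiv` and `launch_grid` are assumptions for illustration; `cdiv` performs the same arithmetic as `triton.cdiv`.

```python
def cdiv(a, b):
    """Ceiling division: number of blocks of size b needed to cover a elements."""
    return (a + b - 1) // b

def launch_grid(x_rows, y_cols, block_x=16, block_y=16):
    """Shape of the 2D launch grid: one program per output block."""
    return (cdiv(x_rows, block_x), cdiv(y_cols, block_y))

grid_even = launch_grid(32, 32)    # (2, 2): dimensions divide evenly
grid_ragged = launch_grid(33, 40)  # (3, 3): the last block in each axis is partial
```

Rounding up rather than down is what guarantees the partial blocks at the edges still get a program, which is also why the kernel needs masks.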

In [1]:
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

def matmul(X, Y):
    x_rows, x_cols = X.shape
    y_rows, y_cols = Y.shape
    assert x_cols == y_rows, "Inner dimensions must match"
    # Output matrix; float16 to match the cast at the end of the kernel
    output = torch.zeros((x_rows, y_cols), device=X.device, dtype=torch.float16)

    BLOCK_SIZE_X = 16
    BLOCK_SIZE_Y = 16
    BLOCK_SIZE_K = 32

    # 2D launch grid: one program per (row block, column block) of the output
    grid = lambda meta: (
        triton.cdiv(x_rows, meta["BLOCK_SIZE_X"]),
        triton.cdiv(y_cols, meta["BLOCK_SIZE_Y"]),
    )

    matmul_kernel[grid](
        X, Y, output,
        x_rows, x_cols, y_rows, y_cols,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )

    return output

We can now define the kernel.  We can use arrays to generate our pointers to reference blocks of data instead of individual rows/columns.  The main tricky part is incrementing the row/column pointer index properly.

This is an example of using an array to generate pointers: `x_ptrs = x_ptr + (x_offsets[:,None] + k_offsets[None,:])` .  We use `x_offsets[:, None]` to index rows, and `k_offsets[None,:]` to index columns:

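- `x_offsets[:,None]` results in a `BLOCK_SIZE_X x 1` array (one entry per row).
- `k_offsets[None,:]` results in a `1 x BLOCK_SIZE_K` array (one entry per column).

Adding them broadcasts the two arrays into a `BLOCK_SIZE_X x BLOCK_SIZE_K` array with the proper memory pointers for a block of data.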

Here's an example, where we want to grab the top-left 4x4 block of an 8x8 matrix `A`:

In [8]:
# Example of addressing a memory block
# Initialize 8x8 matrix
A = torch.rand((8,8))

# Row indices
# Multiply by 8 since that is the width (in columns) of each row
# So each new row starts at old_row_start + 8
a = (torch.arange(4) * 8).reshape(-1,1)
# Column indices for first 4 columns
b = torch.arange(4).reshape(1,-1)

# Indices for top left corner of A
a + b

tensor([[ 0,  1,  2,  3],
        [ 8,  9, 10, 11],
        [16, 17, 18, 19],
        [24, 25, 26, 27]])
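
The same indexing extends to any block, not just the top-left one. The NumPy sketch below uses a hypothetical helper, `block_offsets`, to compute the flat row-major offsets of block `(1, 1)` and check them against ordinary slicing. It assumes a contiguous row-major layout, as the kernel does.

```python
import numpy as np

def block_offsets(row_block, col_block, block_rows, block_cols, n_cols):
    """Flat row-major offsets of one block (hypothetical helper for illustration)."""
    rows = row_block * block_rows + np.arange(block_rows)
    cols = col_block * block_cols + np.arange(block_cols)
    # Row indices are scaled by the row width; column indices have stride 1
    return rows[:, None] * n_cols + cols[None, :]

A = np.arange(64, dtype=np.float32).reshape(8, 8)

# Bottom-right 4x4 block of the 8x8 matrix
offsets = block_offsets(1, 1, 4, 4, A.shape[1])
block = A.ravel()[offsets]  # gather through flat offsets, like pointer arithmetic
```

Here `block` matches `A[4:8, 4:8]`, which confirms the offsets address the intended block.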

Now, we can write the kernel. We have two program ids, for the `X` matrix and the `Y` matrix.  Each program id refers to a block of rows/columns.

We pull out the rows from `X` and the columns from `Y` based on pid, in increments of `BLOCK_SIZE_X` and `BLOCK_SIZE_Y`.

We then iterate across the `K` dimension in steps of `BLOCK_SIZE_K`.  We will have `ceil(X.shape[1] / BLOCK_SIZE_K)` groups to process:
- Select the first `BLOCK_SIZE_K` columns from the `X` block, and the first `BLOCK_SIZE_K` rows from the `Y` block.
- Multiply them, and add the product to an accumulator.
- Move on to the next `BLOCK_SIZE_K` columns and rows, and continue until the entire block of rows is multiplied by the entire block of columns.

We then put the result in the correct position in an output array.  Because we're selecting whole blocks at a time, the last block along any dimension can extend past the edge of the matrix, so we have to be careful to mask the end of every row/column.
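
To see why masking with a fill value of zero is safe, here is a small NumPy sketch of a single row-times-column dot product where the inner dimension is not a multiple of the block size. `masked_load` is a hypothetical stand-in for `tl.load(ptrs, mask=..., other=0.0)`, not a real Triton call, and the sizes are chosen only for illustration.

```python
import numpy as np

def masked_load(flat, offsets, mask, other=0.0):
    """Rough NumPy stand-in for a masked load: masked-off lanes return `other`."""
    safe = np.where(mask, offsets, 0)  # avoid indexing out of bounds
    return np.where(mask, flat[safe], other)

K, BLOCK_SIZE_K = 10, 4
row = np.arange(1, K + 1, dtype=np.float32)  # one row of X: 1, 2, ..., 10
col = np.ones(K, dtype=np.float32)           # one column of Y
k_offsets = np.arange(BLOCK_SIZE_K)

total = 0.0
for k in range((K + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K):  # ceil(K / BLOCK_SIZE_K) steps
    offsets = k * BLOCK_SIZE_K + k_offsets
    mask = k_offsets < K - k * BLOCK_SIZE_K  # same form as the mask in the kernel
    a = masked_load(row, offsets, mask)
    b = masked_load(col, offsets, mask)
    total += float(a @ b)  # masked lanes contribute 0 to the sum
```

The last iteration covers offsets 8 to 11, but only 8 and 9 are valid. The masked lanes load as zero, so `total` still equals the exact dot product.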

The below program:

- Loads the `X` and `Y` blocks, with correct masking
- Iterates across the matrix inner dimension and multiplies the subsets
- Adds everything together in the accumulator
- Writes the output to the correct position in the output matrix

In [14]:
@triton.jit
def matmul_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    x_rows,
    x_cols,
    y_rows,
    y_cols,
    BLOCK_SIZE_X: tl.constexpr, # Row count per block
    BLOCK_SIZE_Y: tl.constexpr, # column count per block
    BLOCK_SIZE_K: tl.constexpr # Inner dim to iterate over (count per iteration)
):
    # Program ids from the 2d grid
    x_pid = tl.program_id(axis=0)
    y_pid = tl.program_id(axis=1)

    # Define the start position for the x and y pointers
    # Remember that we have to stride across the columns when incrementing rows, and vice versa
    x_row_start = x_pid * BLOCK_SIZE_X * x_cols # Start of the block we're selecting in x
    y_col_start = y_pid * BLOCK_SIZE_Y # Start of block in y

    # Get the row and column offsets.  Row offsets need to be multiplied by the number of columns in the matrix
    x_offsets = x_row_start + tl.arange(0, BLOCK_SIZE_X) * x_cols # Get row start index for each row (that's why we multiply by x_cols)
    y_offsets = y_col_start + tl.arange(0, BLOCK_SIZE_Y) # Get column start index for each column (stride is 1)

    # Get the k offsets, for the matrix inner dimension
    k_offsets = tl.arange(0, BLOCK_SIZE_K)

    # Define our x pointers, which will be from column 0 to k within each row
    x_ptrs = x_ptr + (x_offsets[:,None] + k_offsets[None,:])

    # Define our y pointers, which will be from 0 to k within each column
    # We multiply the k offsets by y_cols to get the row start positions
    y_ptrs = y_ptr + (k_offsets[:,None] * y_cols + y_offsets[None,:])

    # The accumulator stores the results as we iterate across k
    # Store in float32 for better numerical precision
    accumulator = tl.zeros((BLOCK_SIZE_X, BLOCK_SIZE_Y), dtype=tl.float32)
    # Iterate across k, increment by BLOCK_SIZE_K
    for k in range(0, tl.cdiv(x_cols, BLOCK_SIZE_K)):
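        # Load the x subset to multiply
        # The mask drops rows past the end of x and columns past the end of the inner dimension
        # x_offsets are row-start offsets (row * x_cols), so compare them against x_rows * x_cols
        # [None, :] adds a 1-length first dimension, so the k mask can broadcast across x_ptrs
        x_mask = (x_offsets[:, None] < x_rows * x_cols) & (k_offsets[None, :] < x_cols - k * BLOCK_SIZE_K)
        a = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # The mask drops rows past the end of the inner dimension and columns past the end of y
        # [:, None] adds a 1-length second dimension, so the k mask can broadcast across y_ptrs
        y_mask = (k_offsets[:, None] < y_rows - k * BLOCK_SIZE_K) & (y_offsets[None, :] < y_cols)
        b = tl.load(y_ptrs, mask=y_mask, other=0.0)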

        # Multiply a and b, then add to the accumulator
        result = tl.dot(a, b)
        accumulator += result

        # Increment the x pointers to go across the rows
        x_ptrs += BLOCK_SIZE_K

        # Increment the y pointers to go down the columns - we need to multiply by y_cols because we're moving down the columns (across the rows)
        y_ptrs += BLOCK_SIZE_K * y_cols

    output = accumulator.to(tl.float16)

    # Find the output pointer positions
    output_x_start = x_pid * BLOCK_SIZE_X
    output_y_start = y_pid * BLOCK_SIZE_Y

    # Row and column indices covered by this output block
    output_x_offsets = output_x_start + tl.arange(0, BLOCK_SIZE_X)
    output_y_offsets = output_y_start + tl.arange(0, BLOCK_SIZE_Y)

    # Output pointers - note that the x offsets need to be multiplied by y_cols to convert from row numbers into row pointers
    output_ptrs = output_ptr + (output_x_offsets[:, None] * y_cols + output_y_offsets[None, :])

    # Store the data, ensuring we don't overflow the rows/columns
    # Output mask ensures we don't write anything outside of the matrix boundaries
    output_mask = (output_x_offsets[:, None] < x_rows) & (output_y_offsets[None, :] < y_cols)
    tl.store(output_ptrs, output, mask=output_mask)
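
To follow the pointer arithmetic without a GPU, it can help to emulate a single program on the CPU. The NumPy sketch below mirrors the kernel's offsets and pointer increments on flat buffers. The names `emulate_program` and `emulate_matmul` are hypothetical, the sketch accumulates and returns float32 instead of casting to float16, and it masks out-of-range rows and columns on every load so that ragged shapes are safe.

```python
import numpy as np

def emulate_program(x, y, x_pid, y_pid, bx=16, by=16, bk=32):
    """Emulate one (x_pid, y_pid) program on flat row-major buffers."""
    x_rows, x_cols = x.shape
    y_rows, y_cols = y.shape
    x_flat, y_flat = x.ravel(), y.ravel()

    x_offsets = (x_pid * bx + np.arange(bx)) * x_cols  # row-start offsets in x
    y_offsets = y_pid * by + np.arange(by)             # column indices in y
    k_offsets = np.arange(bk)

    x_idx = x_offsets[:, None] + k_offsets[None, :]           # like x_ptrs
    y_idx = k_offsets[:, None] * y_cols + y_offsets[None, :]  # like y_ptrs

    acc = np.zeros((bx, by), dtype=np.float32)
    for k in range((x_cols + bk - 1) // bk):
        x_mask = (x_offsets[:, None] < x_rows * x_cols) & (k_offsets[None, :] < x_cols - k * bk)
        y_mask = (k_offsets[:, None] < y_rows - k * bk) & (y_offsets[None, :] < y_cols)
        # Masked loads: out-of-range lanes read index 0, then get replaced by 0.0
        a = np.where(x_mask, x_flat[np.where(x_mask, x_idx, 0)], 0.0)
        b = np.where(y_mask, y_flat[np.where(y_mask, y_idx, 0)], 0.0)
        acc += a.astype(np.float32) @ b.astype(np.float32)
        x_idx = x_idx + bk           # move right along the rows of x
        y_idx = y_idx + bk * y_cols  # move down the columns of y
    return acc

def emulate_matmul(x, y, bx=16, by=16, bk=32):
    """Run every program in the grid sequentially and apply the masked store."""
    out = np.zeros((x.shape[0], y.shape[1]), dtype=np.float32)
    for x_pid in range((x.shape[0] + bx - 1) // bx):
        for y_pid in range((y.shape[1] + by - 1) // by):
            acc = emulate_program(x, y, x_pid, y_pid, bx, by, bk)
            tile = out[x_pid * bx:(x_pid + 1) * bx, y_pid * by:(y_pid + 1) * by]
            tile[...] = acc[:tile.shape[0], :tile.shape[1]]  # clipped, like the store mask
    return out

rng = np.random.default_rng(0)
x = rng.random((20, 50), dtype=np.float32)
y = rng.random((50, 35), dtype=np.float32)
out = emulate_matmul(x, y)
```

The shapes here are deliberately not multiples of the block sizes, so the comparison of `out` against `x @ y` exercises every mask.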

We can now test the kernel:

In [17]:
torch.manual_seed(0)
x = torch.rand((32, 64), device='cuda', dtype=torch.float16)
y = torch.rand((64, 32), device='cuda', dtype=torch.float16)
output_torch = x @ y
output_triton = matmul(x, y)
assert_close(output_triton, output_torch)
print(output_torch[:5,:5])
print(output_triton[:5,:5])

The maximum difference between torch and triton is 6.995553016662598
tensor([[3.3329, 4.2105, 3.4219, 3.6030, 4.6367, 3.0829, 3.5959, 4.0361, 4.3028,
         4.3413, 4.0192, 3.3523, 4.4383, 4.5726, 3.9396, 5.3161],
        [3.9074, 4.2839, 4.1713, 4.1455, 5.0679, 3.8858, 3.5398, 4.4548, 4.3134,
         4.2419, 4.5389, 3.7586, 4.5478, 4.7638, 3.9674, 5.6322],
        [3.6504, 4.1908, 4.2282, 3.1439, 4.4771, 4.5377, 4.1929, 4.2981, 3.7594,
         4.0305, 4.1960, 3.6005, 4.2225, 4.4191, 3.0318, 4.6951],
        [3.7793, 4.2255, 3.6918, 3.2550, 4.5282, 4.0765, 4.0407, 3.7886, 4.0163,
         4.0619, 4.0634, 3.6067, 4.6882, 4.7348, 3.6913, 5.3068],
        [4.6634, 5.1013, 4.2930, 3.9466, 5.3071, 4.6393, 4.4717, 4.4301, 4.7710,
         4.3534, 5.2003, 3.6226, 4.8308, 5.0126, 3.5417, 6.3912],
        [4.1602, 4.5921, 3.8755, 4.0547, 4.9854, 4.0085, 4.2527, 4.6740, 4.6540,
         4.6436, 4.4359, 4.1250, 5.1606, 4.7144, 3.7987, 5.8863],
        [3.6933, 5.3607, 4.3754, 4.1718, 5.5626, 